diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..39e7ae7fd0fdd2d8e5bc370225bb1f3eb8648ac8 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..8a0162bfc076e35c9b4d87579f05f86ff2639a43 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.venv +__pycache__ +.bak +megablocks-moe/.bak +.pytest_cache \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bf9f2ddd7ef78ebf080e4e4cefae031874c39667 --- /dev/null +++ b/README.md @@ -0,0 +1,6 @@ +--- +license: apache-2.0 +tags: + - kernel +--- + diff --git a/build.toml b/build.toml new file mode 100644 index 0000000000000000000000000000000000000000..9a7f0f7f0f98c3f6daff64e90d92ab1c6132c096 --- /dev/null +++ b/build.toml @@ -0,0 +1,30 @@ +[general] +name = "megablocks" +universal = false + +[torch] +src = [ + "torch-ext/torch_binding.cpp", + "torch-ext/torch_binding.h" +] + +[kernel.megablocks] +backend = "cuda" +src = [ + "csrc/new_cumsum.h", + "csrc/new_cumsum.cu", + "csrc/new_histogram.h", + "csrc/new_histogram.cu", + "csrc/new_indices.h", + "csrc/new_indices.cu", + "csrc/new_replicate.cu", + "csrc/new_replicate.h", + "csrc/new_sort.h", + "csrc/new_sort.cu", +] +depends = [ "torch", "cutlass_3_8" ] + +[test] +python-git-packages = [ + { url = "https://github.com/stanford-futuredata/stk.git", rev = "7363137", sha256 = "0m6g5l9nlwaiwybg5j8dhnz159wdpabdnkzapnn3dsifxrsb59vz" } +] \ No newline at end of file diff --git a/csrc/bak.ops.cu b/csrc/bak.ops.cu new file mode 100644 index 0000000000000000000000000000000000000000..50884d8a88b473942abeb87d50b28d2f9d6e8025 --- /dev/null +++ b/csrc/bak.ops.cu @@ -0,0 +1,21 @@ +#include "cumsum.h" +#include "histogram.h" +#include "indices.h" +#include "replicate.h" +#include "sort.h" + +#include + +namespace megablocks { + 
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("exclusive_cumsum", &exclusive_cumsum, "batched exclusive cumsum."); + m.def("histogram", &histogram, "even width histogram."); + m.def("inclusive_cumsum", &inclusive_cumsum, "batched inclusive cumsum"); + m.def("indices", &indices, "indices construction for sparse matrix."); + m.def("replicate_forward", &replicate_forward, "(fwd) replicate a vector dynamically."); + m.def("replicate_backward", &replicate_backward, "(bwd) replicate a vector dynamically."); + m.def("sort", &sort, "key/value sort."); +} + +} // namespace megablocks diff --git a/csrc/cuda_util.h b/csrc/cuda_util.h new file mode 100644 index 0000000000000000000000000000000000000000..66e1d5911a79390c1a1065702b71118e4c20b9f3 --- /dev/null +++ b/csrc/cuda_util.h @@ -0,0 +1,62 @@ +#ifndef BLOCKPARTY_CSRC_CUDA_UTIL_H_ +#define BLOCKPARTY_CSRC_CUDA_UTIL_H_ + +#include +#include +// #include + +namespace megablocks { + +typedef __half2 half2; + +struct __align__(8) half4 { + half2 x, y; +}; + +struct __align__(16) half8 { + half2 x, y, z, w; +}; + +template +__device__ __forceinline__ To BitCast(const From& src) noexcept { + To dst; + std::memcpy(&dst, &src, sizeof(To)); + return dst; +} + +template +__device__ __forceinline__ void Store(const T& value, T* ptr) { + *ptr = value; +} + +template +__device__ __forceinline__ T Load(const T* address) { + return __ldg(address); +} + +__device__ __forceinline__ half4 Load(const half4* address) { + float2 x = __ldg(reinterpret_cast(address)); + return BitCast(x); +} + +__device__ __forceinline__ half8 Load(const half8* address) { + float4 x = __ldg(reinterpret_cast(address)); + return BitCast(x); +} + +template +__device__ __forceinline__ T Zero() { return 0; }; + +template <> +__device__ __forceinline__ half2 Zero() { + return {(c10::Half)0., (c10::Half)0.}; +}; + +template <> +__device__ __forceinline__ half4 Zero() { + return {Zero(), Zero()}; +}; + +} // namespace megablocks + +#endif // BLOCKPARTY_CSRC_CUDA_UTIL_H_ diff --git a/csrc/cumsum.h b/csrc/cumsum.h new file mode 100644 index 0000000000000000000000000000000000000000..e4db9676e82285d1f1177ac3fa071f349117bc2a --- /dev/null +++ b/csrc/cumsum.h @@ -0,0 +1,163 @@ +#define CUB_IGNORE_DEPRECATED_API + +#undef CUB_WRAPPED_NAMESPACE +#define CUB_WRAPPED_NAMESPACE megablocks + +#include + +#include +#include +#include +// #include + +#define CUDA_CALL(code) \ + do { \ + cudaError_t status = code; \ + std::string err = cudaGetErrorString(status); \ + TORCH_CHECK(status == cudaSuccess, err); \ + } while (0) + +namespace megablocks { + +struct Inclusive {}; +struct Exclusive {}; + +template struct Cumsum { + + template< + typename InputIteratorT, + typename OutputIteratorT> + static void Run(void * d_temp_storage, + size_t & temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) { + CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + num_items, + stream));//, + //debug_synchronous)); + } +}; + +template <> struct Cumsum { + template< + typename InputIteratorT, + typename OutputIteratorT> + static void Run(void * d_temp_storage, + size_t & temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) { + CUDA_CALL(cub::DeviceScan::InclusiveSum(d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + num_items, + stream));//, + //debug_synchronous)); + } +}; + +template +void 
cub_cumsum(torch::Tensor x, int dim, torch::Tensor out) { + // Get temporary storage size. + size_t scratchpad_bytes = 0; + Cumsum::Run(nullptr, + scratchpad_bytes, + x.data_ptr(), + out.data_ptr(), + x.size(1), + c10::cuda::getCurrentCUDAStream()); + + // Allocate scratchpad. + // + // NOTE: We scale for the batch dimension so we can run in parallel. + auto options = torch::TensorOptions() + .dtype(torch::kInt8) + .device(x.device()); + torch::Tensor scratchpad = torch::empty(scratchpad_bytes * x.size(0), + options); + + // Run the kernel. + // + // NOTE: Using different streams for each issue does not appear to + // yield performance gains for our problem set. The overhead of + // event/stream synchronization appears to outweigh the benfits. + // We could write a true batched cumsum, but this would require + // significant code duplication from cub and we might move away + // from this formulation anyways. + for (int i = 0; i < x.size(0); ++i) { + void* scratchpad_ptr = (int8_t*)scratchpad.data_ptr() + scratchpad_bytes * i; + Cumsum::Run(scratchpad_ptr, + scratchpad_bytes, + x.data_ptr() + x.size(1) * i, + out.data_ptr() + x.size(1) * i, + x.size(1), + c10::cuda::getCurrentCUDAStream()); + } +} + +void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) { + // Validate the input matrix. + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.ndimension() == 2); + TORCH_CHECK(x.scalar_type() == torch::kInt16 || + x.scalar_type() == torch::kInt32 || + x.scalar_type() == torch::kInt64); + TORCH_CHECK(out.is_cuda()); + TORCH_CHECK(out.ndimension() == 2); + TORCH_CHECK(out.scalar_type() == x.scalar_type()); + + // NOTE: We currently only support contraction across the contiguous + // dimension in the matrix. + TORCH_CHECK(dim == 1); + + switch (x.scalar_type()) { + case torch::kInt16: + cub_cumsum(x, dim, out); + return; + case torch::kInt32: + cub_cumsum(x, dim, out); + return; + } + TORCH_CHECK(x.scalar_type() == torch::kInt64); + cub_cumsum(x, dim, out); +} + +void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) { + // Validate the input matrix. + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.ndimension() == 2); + TORCH_CHECK(x.scalar_type() == torch::kInt16 || + x.scalar_type() == torch::kInt32 || + x.scalar_type() == torch::kInt64); + TORCH_CHECK(out.is_cuda()); + TORCH_CHECK(out.ndimension() == 2); + TORCH_CHECK(out.scalar_type() == x.scalar_type()); + + // NOTE: We currently only support contraction across the contiguous + // dimension in the matrix. + TORCH_CHECK(dim == 1); + + switch (x.scalar_type()) { + case torch::kInt16: + cub_cumsum(x, dim, out); + return; + case torch::kInt32: + cub_cumsum(x, dim, out); + return; + } + TORCH_CHECK(x.scalar_type() == torch::kInt64); + cub_cumsum(x, dim, out); +} + +} // namespace megablocks + +#undef CUB_WRAPPED_NAMESPACE \ No newline at end of file diff --git a/csrc/histogram.h b/csrc/histogram.h new file mode 100644 index 0000000000000000000000000000000000000000..161115b82fa37b86fadbf19dd64d773ff43d0d1a --- /dev/null +++ b/csrc/histogram.h @@ -0,0 +1,86 @@ +#undef CUB_WRAPPED_NAMESPACE +#define CUB_WRAPPED_NAMESPACE megablocks + +#include + +#include +#include +// #include + +#define CUDA_CALL(code) \ + do { \ + cudaError_t status = code; \ + std::string err = cudaGetErrorString(status); \ + TORCH_CHECK(status == cudaSuccess, err); \ + } while (0) + +namespace megablocks { + +template +torch::Tensor cub_histogram(torch::Tensor x, int num_bins) { + // Allocate the count buffer. 
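+  // The output holds one row of `num_bins` int32 counts per batch row; bins are
+  // unit-width over [0, num_bins), so e.g. (illustrative) a row [0, 1, 1, 2] with
+  // num_bins = 3 produces the counts [1, 2, 1].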
+ auto options = torch::TensorOptions() + .dtype(torch::kInt32) + .device(x.device()); + torch::Tensor out = torch::empty({x.size(0), num_bins}, options); + + // Exit early if there is not work to do. + if (out.numel() == 0) return out; + + // Get scratchpad size. + size_t scratchpad_bytes = 0; + CUDA_CALL(cub::DeviceHistogram::HistogramEven(nullptr, + scratchpad_bytes, + x.data_ptr(), + out.data_ptr(), + /*num_levels=*/num_bins + 1, + /*lower_level=*/0, + /*upper_level=*/num_bins, + /*num_samples=*/int(x.size(1)), + c10::cuda::getCurrentCUDAStream())); + + // Allocate scratchpad. + options = torch::TensorOptions().dtype(torch::kInt8).device(x.device()); + torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options); + + // Run the kernel. + for (int i = 0; i < x.size(0); ++i) { + CUDA_CALL(cub::DeviceHistogram::HistogramEven(scratchpad.data_ptr(), + scratchpad_bytes, + x.data_ptr() + x.size(1) * i, + out.data_ptr() + out.size(1) * i, + /*num_levels=*/num_bins + 1, + /*lower_level=*/0, + /*upper_level=*/num_bins, + /*num_samples=*/int(x.size(1)), + c10::cuda::getCurrentCUDAStream())); + } + return out; +} + +torch::Tensor histogram(torch::Tensor x, int num_bins) { + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.ndimension() == 1 || x.ndimension() == 2); + TORCH_CHECK(x.scalar_type() == torch::kInt16 || + x.scalar_type() == torch::kInt32 || + x.scalar_type() == torch::kInt64); + bool no_batch = x.ndimension() == 1; + if (no_batch) x = x.view({1, x.numel()}); + + if (x.scalar_type() == torch::kInt16) { + auto out = cub_histogram(x, num_bins); + return no_batch ? out.flatten() : out; + } else if (x.scalar_type() == torch::kInt32) { + auto out = cub_histogram(x, num_bins); + return no_batch ? out.flatten() : out; + } else { + TORCH_CHECK(x.scalar_type() == torch::kInt64); + auto out = cub_histogram(x, num_bins); + return no_batch ? out.flatten() : out; + } +} + +} // namespace megablocks + +#undef CUDA_CALL +#undef CUB_WRAPPED_NAMESPACE diff --git a/csrc/indices.h b/csrc/indices.h new file mode 100644 index 0000000000000000000000000000000000000000..a2f0d2d6a90e46e2468221f253eb787105aa2b93 --- /dev/null +++ b/csrc/indices.h @@ -0,0 +1,95 @@ +#include +#include +// #include +#include + +#define CUDA_CALL(code) \ + do { \ + cudaError_t status = code; \ + std::string err = cudaGetErrorString(status); \ + TORCH_CHECK(status == cudaSuccess, err); \ + } while (0) + +namespace megablocks { +namespace construct_indices { + +// We expect the number of outputs per block to be small. For +// example, with ffn_hidden_size=4096, we only need to write +// 32 elements per block per iteration. +const int kThreadsPerBlock = 32; + +__global__ void __launch_bounds__(kThreadsPerBlock) + ConstructIndicesKernel(short * __restrict__ indices, + int num_columns, + int block_size, + const int * __restrict__ padded_bins) { + // Load the offset for this bins indices. + int start = 0; + if (blockIdx.x > 0) start = __ldg(padded_bins + blockIdx.x - 1); + int end = __ldg(padded_bins + blockIdx.x); + + // Divide the start and end into blocks. + start /= block_size; + end /= block_size; + + // Offset the output buffer to the start of the bin. + indices += (start + blockIdx.y) * num_columns + threadIdx.x; + + // Write the indices to the output. 
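+  // Each trip of the outer loop below fills one block-row of this bin; the inner loop
+  // strides kThreadsPerBlock at a time across the `num_columns` entries, writing the
+  // global column id (blockIdx.x * num_columns + bid) for each block position.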
+ int bin_offset = blockIdx.y; + int num_rows = end - start; + for (; bin_offset < num_rows; num_rows -= gridDim.y) { + short *out = indices; + for (int bid = threadIdx.x; bid < num_columns; bid += kThreadsPerBlock) { + *out = bid + (blockIdx.x * num_columns); + out += kThreadsPerBlock; + } + indices += gridDim.y * num_columns; + } +} + +cudaError_t ConstructIndices(short * __restrict__ indices, + int output_block_rows, + int output_block_columns, + int block_size, + const int * __restrict__ padded_bins, + int num_bins, + cudaStream_t stream) { + dim3 block_dim(kThreadsPerBlock); + dim3 grid_dim(num_bins, (int)std::ceil((float)output_block_rows / num_bins)); + ConstructIndicesKernel<<>>(indices, + output_block_columns, + block_size, + padded_bins); + return cudaGetLastError(); +} + +} // namespace construct_indices + +void indices(torch::Tensor padded_bins, + int block_size, + int output_block_rows, + int output_block_columns, + torch::Tensor out) { + TORCH_CHECK(padded_bins.is_cuda()); + TORCH_CHECK(padded_bins.ndimension() == 1); + TORCH_CHECK(padded_bins.scalar_type() == torch::kInt); + + TORCH_CHECK(out.is_cuda()); + TORCH_CHECK(out.ndimension() == 1); + TORCH_CHECK(out.scalar_type() == torch::kInt16); + TORCH_CHECK(out.numel() == (output_block_rows * output_block_columns)); + + // Exit early if there is no work to do. + if (out.numel() == 0) return; + + CUDA_CALL(construct_indices::ConstructIndices(out.data_ptr(), + output_block_rows, + output_block_columns, + block_size, + padded_bins.data_ptr(), + padded_bins.numel(), + c10::cuda::getCurrentCUDAStream())); +} + +} // namespace megablocks diff --git a/csrc/new_cumsum.cu b/csrc/new_cumsum.cu new file mode 100644 index 0000000000000000000000000000000000000000..175fdc91779435355d138ffb87dd0c244d0e2da8 --- /dev/null +++ b/csrc/new_cumsum.cu @@ -0,0 +1,161 @@ +#define CUB_IGNORE_DEPRECATED_API + +#undef CUB_WRAPPED_NAMESPACE +#define CUB_WRAPPED_NAMESPACE megablocks + +#include "new_cumsum.h" +#include +#include +#include + +#define CUDA_CALL(code) \ + do { \ + cudaError_t status = code; \ + std::string err = cudaGetErrorString(status); \ + TORCH_CHECK(status == cudaSuccess, err); \ + } while (0) + +namespace megablocks { + +struct Inclusive {}; +struct Exclusive {}; + +template struct Cumsum { + + template< + typename InputIteratorT, + typename OutputIteratorT> + static void Run(void * d_temp_storage, + size_t & temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) { + CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + num_items, + stream));//, + //debug_synchronous)); + } +}; + +template <> struct Cumsum { + template< + typename InputIteratorT, + typename OutputIteratorT> + static void Run(void * d_temp_storage, + size_t & temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_items, + cudaStream_t stream = 0, + bool debug_synchronous = false) { + CUDA_CALL(cub::DeviceScan::InclusiveSum(d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + num_items, + stream));//, + //debug_synchronous)); + } +}; + +template +void cub_cumsum(torch::Tensor x, int dim, torch::Tensor out) { + // Get temporary storage size. + size_t scratchpad_bytes = 0; + Cumsum::Run(nullptr, + scratchpad_bytes, + x.data_ptr(), + out.data_ptr(), + x.size(1), + c10::cuda::getCurrentCUDAStream()); + + // Allocate scratchpad. + // + // NOTE: We scale for the batch dimension so we can run in parallel. 
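+  // Each batch row gets its own `scratchpad_bytes` slice of this buffer, so the
+  // per-row scans issued in the loop below never share temporary storage.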
+ auto options = torch::TensorOptions() + .dtype(torch::kInt8) + .device(x.device()); + torch::Tensor scratchpad = torch::empty(scratchpad_bytes * x.size(0), + options); + + // Run the kernel. + // + // NOTE: Using different streams for each issue does not appear to + // yield performance gains for our problem set. The overhead of + // event/stream synchronization appears to outweigh the benfits. + // We could write a true batched cumsum, but this would require + // significant code duplication from cub and we might move away + // from this formulation anyways. + for (int i = 0; i < x.size(0); ++i) { + void* scratchpad_ptr = (int8_t*)scratchpad.data_ptr() + scratchpad_bytes * i; + Cumsum::Run(scratchpad_ptr, + scratchpad_bytes, + x.data_ptr() + x.size(1) * i, + out.data_ptr() + x.size(1) * i, + x.size(1), + c10::cuda::getCurrentCUDAStream()); + } +} + +void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) { + // Validate the input matrix. + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.ndimension() == 2); + TORCH_CHECK(x.scalar_type() == torch::kInt16 || + x.scalar_type() == torch::kInt32 || + x.scalar_type() == torch::kInt64); + TORCH_CHECK(out.is_cuda()); + TORCH_CHECK(out.ndimension() == 2); + TORCH_CHECK(out.scalar_type() == x.scalar_type()); + + // NOTE: We currently only support contraction across the contiguous + // dimension in the matrix. + TORCH_CHECK(dim == 1); + + switch (x.scalar_type()) { + case torch::kInt16: + cub_cumsum(x, dim, out); + return; + case torch::kInt32: + cub_cumsum(x, dim, out); + return; + } + TORCH_CHECK(x.scalar_type() == torch::kInt64); + cub_cumsum(x, dim, out); +} + +void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) { + // Validate the input matrix. + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.ndimension() == 2); + TORCH_CHECK(x.scalar_type() == torch::kInt16 || + x.scalar_type() == torch::kInt32 || + x.scalar_type() == torch::kInt64); + TORCH_CHECK(out.is_cuda()); + TORCH_CHECK(out.ndimension() == 2); + TORCH_CHECK(out.scalar_type() == x.scalar_type()); + + // NOTE: We currently only support contraction across the contiguous + // dimension in the matrix. 
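+  // For example, an inclusive cumsum of the row [1, 2, 3] is [1, 3, 6], while the
+  // exclusive variant above yields [0, 1, 3].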
+ TORCH_CHECK(dim == 1); + + switch (x.scalar_type()) { + case torch::kInt16: + cub_cumsum(x, dim, out); + return; + case torch::kInt32: + cub_cumsum(x, dim, out); + return; + } + TORCH_CHECK(x.scalar_type() == torch::kInt64); + cub_cumsum(x, dim, out); +} + +} // namespace megablocks + +#undef CUB_WRAPPED_NAMESPACE \ No newline at end of file diff --git a/csrc/new_cumsum.h b/csrc/new_cumsum.h new file mode 100644 index 0000000000000000000000000000000000000000..b18f282398b3bb34577a8c1faf97455bd504ceb5 --- /dev/null +++ b/csrc/new_cumsum.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace megablocks { + +// Forward declarations for the public interface functions +void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out); +void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out); + +} // namespace megablocks \ No newline at end of file diff --git a/csrc/new_histogram.cu b/csrc/new_histogram.cu new file mode 100644 index 0000000000000000000000000000000000000000..cabd6a7a014407e23a4286fa48bd43341987f1c6 --- /dev/null +++ b/csrc/new_histogram.cu @@ -0,0 +1,85 @@ +#undef CUB_WRAPPED_NAMESPACE +#define CUB_WRAPPED_NAMESPACE megablocks + +#include "new_histogram.h" +#include +#include +#include + +#define CUDA_CALL(code) \ + do { \ + cudaError_t status = code; \ + std::string err = cudaGetErrorString(status); \ + TORCH_CHECK(status == cudaSuccess, err); \ + } while (0) + +namespace megablocks { + +template +torch::Tensor cub_histogram(torch::Tensor x, int num_bins) { + // Allocate the count buffer. + auto options = torch::TensorOptions() + .dtype(torch::kInt32) + .device(x.device()); + torch::Tensor out = torch::empty({x.size(0), num_bins}, options); + + // Exit early if there is not work to do. + if (out.numel() == 0) return out; + + // Get scratchpad size. + size_t scratchpad_bytes = 0; + CUDA_CALL(cub::DeviceHistogram::HistogramEven(nullptr, + scratchpad_bytes, + x.data_ptr(), + out.data_ptr(), + /*num_levels=*/num_bins + 1, + /*lower_level=*/0, + /*upper_level=*/num_bins, + /*num_samples=*/int(x.size(1)), + c10::cuda::getCurrentCUDAStream())); + + // Allocate scratchpad. + options = torch::TensorOptions().dtype(torch::kInt8).device(x.device()); + torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options); + + // Run the kernel. + for (int i = 0; i < x.size(0); ++i) { + CUDA_CALL(cub::DeviceHistogram::HistogramEven(scratchpad.data_ptr(), + scratchpad_bytes, + x.data_ptr() + x.size(1) * i, + out.data_ptr() + out.size(1) * i, + /*num_levels=*/num_bins + 1, + /*lower_level=*/0, + /*upper_level=*/num_bins, + /*num_samples=*/int(x.size(1)), + c10::cuda::getCurrentCUDAStream())); + } + return out; +} + +torch::Tensor histogram(torch::Tensor x, int num_bins) { + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.ndimension() == 1 || x.ndimension() == 2); + TORCH_CHECK(x.scalar_type() == torch::kInt16 || + x.scalar_type() == torch::kInt32 || + x.scalar_type() == torch::kInt64); + bool no_batch = x.ndimension() == 1; + if (no_batch) x = x.view({1, x.numel()}); + + if (x.scalar_type() == torch::kInt16) { + auto out = cub_histogram(x, num_bins); + return no_batch ? out.flatten() : out; + } else if (x.scalar_type() == torch::kInt32) { + auto out = cub_histogram(x, num_bins); + return no_batch ? out.flatten() : out; + } else { + TORCH_CHECK(x.scalar_type() == torch::kInt64); + auto out = cub_histogram(x, num_bins); + return no_batch ? 
out.flatten() : out; + } +} + +} // namespace megablocks + +#undef CUDA_CALL +#undef CUB_WRAPPED_NAMESPACE \ No newline at end of file diff --git a/csrc/new_histogram.h b/csrc/new_histogram.h new file mode 100644 index 0000000000000000000000000000000000000000..8d742dd8dc3ef2c33a660032f700746fb56de828 --- /dev/null +++ b/csrc/new_histogram.h @@ -0,0 +1,10 @@ +#pragma once + +#include + +namespace megablocks { + +// Public interface function for computing histograms +torch::Tensor histogram(torch::Tensor x, int num_bins); + +} // namespace megablocks \ No newline at end of file diff --git a/csrc/new_indices.cu b/csrc/new_indices.cu new file mode 100644 index 0000000000000000000000000000000000000000..c1eef972c90868e17403542d173985a802b20308 --- /dev/null +++ b/csrc/new_indices.cu @@ -0,0 +1,97 @@ +#include "new_indices.h" +#include +#include +#include + +#define CUDA_CALL(code) \ + do { \ + cudaError_t status = code; \ + std::string err = cudaGetErrorString(status); \ + TORCH_CHECK(status == cudaSuccess, err); \ + } while (0) + +namespace megablocks { +namespace construct_indices { + +// We expect the number of outputs per block to be small. For +// example, with ffn_hidden_size=4096, we only need to write +// 32 elements per block per iteration. +const int kThreadsPerBlock = 32; + +__global__ void __launch_bounds__(kThreadsPerBlock) + ConstructIndicesKernel(short * __restrict__ indices, + int num_columns, + int block_size, + const int * __restrict__ padded_bins) { + // Load the offset for this bins indices. + int start = 0; + if (blockIdx.x > 0) start = __ldg(padded_bins + blockIdx.x - 1); + int end = __ldg(padded_bins + blockIdx.x); + + // Divide the start and end into blocks. + start /= block_size; + end /= block_size; + + // Offset the output buffer to the start of the bin. + indices += (start + blockIdx.y) * num_columns + threadIdx.x; + + // Write the indices to the output. + int bin_offset = blockIdx.y; + int num_rows = end - start; + for (; bin_offset < num_rows; num_rows -= gridDim.y) { + short *out = indices; + for (int bid = threadIdx.x; bid < num_columns; bid += kThreadsPerBlock) { + *out = bid + (blockIdx.x * num_columns); + out += kThreadsPerBlock; + } + indices += gridDim.y * num_columns; + } +} + +cudaError_t ConstructIndices(short * __restrict__ indices, + int output_block_rows, + int output_block_columns, + int block_size, + const int * __restrict__ padded_bins, + int num_bins, + cudaStream_t stream) { + dim3 block_dim(kThreadsPerBlock); + dim3 grid_dim(num_bins, (int)std::ceil((float)output_block_rows / num_bins)); + ConstructIndicesKernel<<>>(indices, + output_block_columns, + block_size, + padded_bins); + return cudaGetLastError(); +} + +} // namespace construct_indices + +void indices(torch::Tensor padded_bins, + int block_size, + int output_block_rows, + int output_block_columns, + torch::Tensor out) { + TORCH_CHECK(padded_bins.is_cuda()); + TORCH_CHECK(padded_bins.ndimension() == 1); + TORCH_CHECK(padded_bins.scalar_type() == torch::kInt); + + TORCH_CHECK(out.is_cuda()); + TORCH_CHECK(out.ndimension() == 1); + TORCH_CHECK(out.scalar_type() == torch::kInt16); + TORCH_CHECK(out.numel() == (output_block_rows * output_block_columns)); + + // Exit early if there is no work to do. 
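+  // NOTE: every block-row that falls in bin `b` is filled with the column ids
+  // [b * output_block_columns, (b + 1) * output_block_columns), i.e. the columns
+  // owned by that bin in the block-sparse layout (see ConstructIndicesKernel above).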
+ if (out.numel() == 0) return; + + CUDA_CALL(construct_indices::ConstructIndices(out.data_ptr(), + output_block_rows, + output_block_columns, + block_size, + padded_bins.data_ptr(), + padded_bins.numel(), + c10::cuda::getCurrentCUDAStream())); +} + +} // namespace megablocks + +#undef CUDA_CALL \ No newline at end of file diff --git a/csrc/new_indices.h b/csrc/new_indices.h new file mode 100644 index 0000000000000000000000000000000000000000..8744303e52ee1e238cd7c903d3e968a0656d9422 --- /dev/null +++ b/csrc/new_indices.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace megablocks { + +// Public interface function for constructing indices from padded bins +void indices(torch::Tensor padded_bins, + int block_size, + int output_block_rows, + int output_block_columns, + torch::Tensor out); + +} // namespace megablocks \ No newline at end of file diff --git a/csrc/new_replicate.cu b/csrc/new_replicate.cu new file mode 100644 index 0000000000000000000000000000000000000000..db2a450ea7ecf8bcc4a3c6d3676ee4bd387d3d3a --- /dev/null +++ b/csrc/new_replicate.cu @@ -0,0 +1,210 @@ +#undef CUB_WRAPPED_NAMESPACE +#define CUB_WRAPPED_NAMESPACE megablocks + +#include "new_replicate.h" +#include +#include +#include +#include + +#define CUDA_CALL(code) \ + do { \ + cudaError_t status = code; \ + std::string err = cudaGetErrorString(status); \ + TORCH_CHECK(status == cudaSuccess, err); \ + } while (0) + +namespace megablocks { +namespace replicate { + +template +__global__ void __launch_bounds__(kThreadsPerBlock) + ReplicateForwardKernel(T * __restrict__ x, + int * __restrict__ bins, + T * __restrict__ out, + int columns) { + // Offset to this threadblocks batch. + // + // x is [batch_size, num_bins] + // out is [batch_size, columns] + // bins is [num_bins] + int batch_idx = blockIdx.y; + int num_bins = gridDim.x; + x += batch_idx * num_bins; + out += batch_idx * columns; + + // Load the start/end for this bin. + int bin_idx = blockIdx.x; + int start = 0; + if (bin_idx > 0) start = __ldg(bins + bin_idx - 1); + int end = __ldg(bins + bin_idx); + + // Load the value to replicate. + T value = __ldg((T*)x + bin_idx); + + // Offset to this threadblocks bin and this threads + // offset within the bin. + int bin_offset = blockIdx.z * kThreadsPerBlock + threadIdx.x; + out += start + bin_offset; + + // Replicate the value to the output. + // + // TODO(tgale): Vectorize these stores. + int num_elements = end - start; + const int kElementsPerLoop = gridDim.z * kThreadsPerBlock; + T *out_ptr = (T*)out; + for (; bin_offset < num_elements; num_elements -= kElementsPerLoop) { + *out_ptr = value; + out_ptr += kElementsPerLoop; + } +} + +template +cudaError_t ReplicateForward(T *x, + int batch_size, + int num_bins, + int *bins, + T *out, + int columns, + cudaStream_t stream) { + const int kThreadsPerBlock = 64; + dim3 block_dim(kThreadsPerBlock, 1, 1); + int group_size = std::ceil((float)columns / (num_bins * kThreadsPerBlock)); + dim3 grid_dim(num_bins, batch_size, group_size); + ReplicateForwardKernel<<< + grid_dim, block_dim, 0, stream>>>(x, bins, out, columns); + return cudaGetLastError(); +} + +void cub_segmented_reduce(torch::Tensor grad, + torch::Tensor bins, + torch::Tensor out, + cudaStream_t stream) { + // Append a zero to the bin boundaries for CUB. 
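+  // cub::DeviceSegmentedReduce expects explicit begin/end offsets per segment, so we
+  // build offsets = [0, bins[0], bins[1], ...]; segment i then covers
+  // [offsets[i], offsets[i + 1]) since `bins` already holds cumulative boundaries.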
+ torch::Tensor offsets = torch::empty(bins.numel() + 1, bins.options()); + CUDA_CALL(cudaMemsetAsync(offsets.data_ptr(), + 0, + offsets.numel() * sizeof(int), + stream)); + CUDA_CALL(cudaMemcpyAsync(offsets.data_ptr() + 1, + bins.data_ptr(), + bins.numel() * sizeof(int), + cudaMemcpyDeviceToDevice, + stream)); + + // Get temporary buffer size. + size_t scratchpad_bytes = 0; + CUDA_CALL(cub::DeviceSegmentedReduce::Sum(nullptr, + scratchpad_bytes, + grad.data_ptr(), + out.data_ptr(), + bins.numel(), + offsets.data_ptr(), + offsets.data_ptr() + 1, + stream)); + + // Allocate scratchpad. + auto options = torch::TensorOptions() + .dtype(torch::kInt8) + .device(grad.device()); + torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options); + + // Run the kernel for each batch item. + for (int i = 0; i < grad.size(0); ++i) { + int num_bins = out.size(1); + int num_values = grad.size(1); + CUDA_CALL(cub::DeviceSegmentedReduce::Sum(scratchpad.data_ptr(), + scratchpad_bytes, + grad.data_ptr() + i * num_values, + out.data_ptr() + i * num_bins, + bins.numel(), + offsets.data_ptr(), + offsets.data_ptr() + 1, + stream)); + } +} + +} // namespace replicate + +void replicate_forward(torch::Tensor x, + torch::Tensor bins, + torch::Tensor out) { + // Validate the inputs. + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.ndimension() == 2); + TORCH_CHECK(x.scalar_type() == torch::kFloat16 || + x.scalar_type() == torch::kInt16 || + x.scalar_type() == torch::kInt32); + TORCH_CHECK(bins.is_cuda()); + TORCH_CHECK(bins.ndimension() == 1); + TORCH_CHECK(bins.scalar_type() == torch::kInt); + TORCH_CHECK(out.is_cuda()); + TORCH_CHECK(out.ndimension() == 2); + TORCH_CHECK(out.scalar_type() == x.scalar_type()); + + // Batch dimensions should match for input/output. + TORCH_CHECK(x.size(0) == out.size(0)); + + // One input for each bin (in each batch). + TORCH_CHECK(x.size(1) == bins.size(0)); + + // Exit early if there is no work to do. + if (out.numel() == 0) return; + + switch (x.scalar_type()) { + case torch::kFloat16: + CUDA_CALL(replicate::ReplicateForward(x.data_ptr(), + x.size(0), + x.size(1), + bins.data_ptr(), + out.data_ptr(), + out.size(1), + c10::cuda::getCurrentCUDAStream())); + return; + case torch::kInt32: + CUDA_CALL(replicate::ReplicateForward(x.data_ptr(), + x.size(0), + x.size(1), + bins.data_ptr(), + out.data_ptr(), + out.size(1), + c10::cuda::getCurrentCUDAStream())); + return; + } + TORCH_CHECK(x.scalar_type() == torch::kInt16); + CUDA_CALL(replicate::ReplicateForward(x.data_ptr(), + x.size(0), + x.size(1), + bins.data_ptr(), + out.data_ptr(), + out.size(1), + c10::cuda::getCurrentCUDAStream())); +} + +void replicate_backward(torch::Tensor grad, + torch::Tensor bins, + torch::Tensor out) { + // Validate the inputs. + TORCH_CHECK(grad.is_cuda()); + TORCH_CHECK(grad.ndimension() == 2); + TORCH_CHECK(grad.scalar_type() == torch::kFloat16); + TORCH_CHECK(bins.is_cuda()); + TORCH_CHECK(bins.ndimension() == 1); + TORCH_CHECK(bins.scalar_type() == torch::kInt); + TORCH_CHECK(out.is_cuda()); + TORCH_CHECK(out.ndimension() == 2); + TORCH_CHECK(out.scalar_type() == torch::kFloat16); + + // Batch dimensions should match for input/output. + TORCH_CHECK(grad.size(0) == out.size(0)); + + // One output for each bin (in each batch). 
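+  // The backward of replication is a segmented sum: every gradient element produced
+  // from bin i in the forward pass is reduced back into out[:, i] by
+  // cub_segmented_reduce above.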
+ TORCH_CHECK(out.size(1) == bins.size(0)); + + replicate::cub_segmented_reduce(grad, bins, out, c10::cuda::getCurrentCUDAStream()); +} + +} // namespace megablocks + +#undef CUDA_CALL +#undef CUB_WRAPPED_NAMESPACE \ No newline at end of file diff --git a/csrc/new_replicate.h b/csrc/new_replicate.h new file mode 100644 index 0000000000000000000000000000000000000000..2edb8c4c7415099c354963407dd8a5f6a1f1dc7f --- /dev/null +++ b/csrc/new_replicate.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +namespace megablocks { + +// Forward pass: replicate values from x according to bin sizes +void replicate_forward(torch::Tensor x, + torch::Tensor bins, + torch::Tensor out); + +// Backward pass: reduce gradients back to bins using segmented reduction +void replicate_backward(torch::Tensor grad, + torch::Tensor bins, + torch::Tensor out); + +} // namespace megablocks \ No newline at end of file diff --git a/csrc/new_sort.cu b/csrc/new_sort.cu new file mode 100644 index 0000000000000000000000000000000000000000..08a7a05566bac78bcbfafc0f18dc842a6983fd46 --- /dev/null +++ b/csrc/new_sort.cu @@ -0,0 +1,90 @@ +#undef CUB_WRAPPED_NAMESPACE +#define CUB_WRAPPED_NAMESPACE megablocks + +#include "new_sort.h" +#include +#include +#include + +#define CUDA_CALL(code) \ + do { \ + cudaError_t status = code; \ + std::string err = cudaGetErrorString(status); \ + TORCH_CHECK(status == cudaSuccess, err); \ + } while (0) + +namespace megablocks { + +template +void cub_radix_sort(torch::Tensor x, + int end_bit, + torch::Tensor x_out, + torch::Tensor iota_out) { + // Get iota for values in sort. + torch::Tensor iota = torch::arange(0, x.numel(), x.options()); + + // Get temporary buffer size. + size_t scratchpad_bytes = 0; + CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr, + scratchpad_bytes, + x.data_ptr(), + x_out.data_ptr(), + iota.data_ptr(), + iota_out.data_ptr(), + x.numel(), + /*begin_bit*/0, + /*end_bit=*/end_bit, + c10::cuda::getCurrentCUDAStream())); + + // Allocate scratchpad. + auto options = torch::TensorOptions() + .dtype(torch::kInt8) + .device(x.device()); + torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options); + + // Run the kernel. + CUDA_CALL(cub::DeviceRadixSort::SortPairs(scratchpad.data_ptr(), + scratchpad_bytes, + x.data_ptr(), + x_out.data_ptr(), + iota.data_ptr(), + iota_out.data_ptr(), + x.numel(), + /*begin_bit=*/0, + /*end_bit=*/end_bit, + c10::cuda::getCurrentCUDAStream())); +} + +void sort(torch::Tensor x, + int end_bit, + torch::Tensor x_out, + torch::Tensor iota_out) { + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.ndimension() == 1); + TORCH_CHECK(x.scalar_type() == torch::kInt16 || + x.scalar_type() == torch::kInt32 || + x.scalar_type() == torch::kInt64); + TORCH_CHECK(x_out.is_cuda()); + TORCH_CHECK(x_out.ndimension() == 1); + TORCH_CHECK(x_out.scalar_type() == x.scalar_type()); + TORCH_CHECK(iota_out.is_cuda()); + TORCH_CHECK(iota_out.ndimension() == 1); + TORCH_CHECK(iota_out.scalar_type() == x.scalar_type()); + + // Exit early if there is not work to do. 
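+  // NOTE: `end_bit` bounds how many key bits the radix sort inspects; passing only the
+  // bits actually needed for the largest key (e.g. the bits required to represent the
+  // expert count when sorting expert ids) reduces the number of radix passes.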
+ if (x_out.numel() == 0) return; + + switch (x.scalar_type()) { + case torch::kInt16: + return cub_radix_sort(x, end_bit, x_out, iota_out); + case torch::kInt32: + return cub_radix_sort(x, end_bit, x_out, iota_out); + } + TORCH_CHECK(x.scalar_type() == torch::kInt64); + return cub_radix_sort(x, end_bit, x_out, iota_out); +} + +} // namespace megablocks + +#undef CUDA_CALL +#undef CUB_WRAPPED_NAMESPACE \ No newline at end of file diff --git a/csrc/new_sort.h b/csrc/new_sort.h new file mode 100644 index 0000000000000000000000000000000000000000..2fa05fccc3c7d4f1581c73c9bd5bd48d10705f84 --- /dev/null +++ b/csrc/new_sort.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace megablocks { + +// Public interface function for radix sorting with indices +void sort(torch::Tensor x, + int end_bit, + torch::Tensor x_out, + torch::Tensor iota_out); + +} // namespace megablocks \ No newline at end of file diff --git a/csrc/replicate.h b/csrc/replicate.h new file mode 100644 index 0000000000000000000000000000000000000000..5c857e68a7ad54d3bb26b6c307b6babd879f71bd --- /dev/null +++ b/csrc/replicate.h @@ -0,0 +1,211 @@ +#undef CUB_WRAPPED_NAMESPACE +#define CUB_WRAPPED_NAMESPACE megablocks + +#include + +#include +#include +#include +// #include + +#define CUDA_CALL(code) \ + do { \ + cudaError_t status = code; \ + std::string err = cudaGetErrorString(status); \ + TORCH_CHECK(status == cudaSuccess, err); \ + } while (0) + +namespace megablocks { +namespace replicate { + +template +__global__ void __launch_bounds__(kThreadsPerBlock) + ReplicateForwardKernel(T * __restrict__ x, + int * __restrict__ bins, + T * __restrict__ out, + int columns) { + // Offset to this threadblocks batch. + // + // x is [batch_size, num_bins] + // out is [batch_size, columns] + // bins is [num_bins] + int batch_idx = blockIdx.y; + int num_bins = gridDim.x; + x += batch_idx * num_bins; + out += batch_idx * columns; + + // Load the start/end for this bin. + int bin_idx = blockIdx.x; + int start = 0; + if (bin_idx > 0) start = __ldg(bins + bin_idx - 1); + int end = __ldg(bins + bin_idx); + + // Load the value to replicate. + T value = __ldg((T*)x + bin_idx); + + // Offset to this threadblocks bin and this threads + // offset within the bin. + int bin_offset = blockIdx.z * kThreadsPerBlock + threadIdx.x; + out += start + bin_offset; + + // Replicate the value to the output. + // + // TODO(tgale): Vectorize these stores. + int num_elements = end - start; + const int kElementsPerLoop = gridDim.z * kThreadsPerBlock; + T *out_ptr = (T*)out; + for (; bin_offset < num_elements; num_elements -= kElementsPerLoop) { + *out_ptr = value; + out_ptr += kElementsPerLoop; + } +} + +template +cudaError_t ReplicateForward(T *x, + int batch_size, + int num_bins, + int *bins, + T *out, + int columns, + cudaStream_t stream) { + const int kThreadsPerBlock = 64; + dim3 block_dim(kThreadsPerBlock, 1, 1); + int group_size = std::ceil((float)columns / (num_bins * kThreadsPerBlock)); + dim3 grid_dim(num_bins, batch_size, group_size); + ReplicateForwardKernel<<< + grid_dim, block_dim, 0, stream>>>(x, bins, out, columns); + return cudaGetLastError(); +} + +void cub_segmented_reduce(torch::Tensor grad, + torch::Tensor bins, + torch::Tensor out, + cudaStream_t stream) { + // Append a zero to the bin boundaries for CUB. 
+ torch::Tensor offsets = torch::empty(bins.numel() + 1, bins.options()); + CUDA_CALL(cudaMemsetAsync(offsets.data_ptr(), + 0, + offsets.numel() * sizeof(int), + stream)); + CUDA_CALL(cudaMemcpyAsync(offsets.data_ptr() + 1, + bins.data_ptr(), + bins.numel() * sizeof(int), + cudaMemcpyDeviceToDevice, + stream)); + + // Get temporary buffer size. + size_t scratchpad_bytes = 0; + CUDA_CALL(cub::DeviceSegmentedReduce::Sum(nullptr, + scratchpad_bytes, + grad.data_ptr(), + out.data_ptr(), + bins.numel(), + offsets.data_ptr(), + offsets.data_ptr() + 1, + stream)); + + // Allocate scratchpad. + auto options = torch::TensorOptions() + .dtype(torch::kInt8) + .device(grad.device()); + torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options); + + // Run the kernel for each batch item. + for (int i = 0; i < grad.size(0); ++i) { + int num_bins = out.size(1); + int num_values = grad.size(1); + CUDA_CALL(cub::DeviceSegmentedReduce::Sum(scratchpad.data_ptr(), + scratchpad_bytes, + grad.data_ptr() + i * num_values, + out.data_ptr() + i * num_bins, + bins.numel(), + offsets.data_ptr(), + offsets.data_ptr() + 1, + stream)); + } +} + +} // namespace replicate + +void replicate_forward(torch::Tensor x, + torch::Tensor bins, + torch::Tensor out) { + // Validate the inputs. + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.ndimension() == 2); + TORCH_CHECK(x.scalar_type() == torch::kFloat16 || + x.scalar_type() == torch::kInt16 || + x.scalar_type() == torch::kInt32); + TORCH_CHECK(bins.is_cuda()); + TORCH_CHECK(bins.ndimension() == 1); + TORCH_CHECK(bins.scalar_type() == torch::kInt); + TORCH_CHECK(out.is_cuda()); + TORCH_CHECK(out.ndimension() == 2); + TORCH_CHECK(out.scalar_type() == x.scalar_type()); + + // Batch dimensions should match for input/output. + TORCH_CHECK(x.size(0) == out.size(0)); + + // One input for each bin (in each batch). + TORCH_CHECK(x.size(1) == bins.size(0)); + + // Exit early if there is no work to do. + if (out.numel() == 0) return; + + switch (x.scalar_type()) { + case torch::kFloat16: + CUDA_CALL(replicate::ReplicateForward(x.data_ptr(), + x.size(0), + x.size(1), + bins.data_ptr(), + out.data_ptr(), + out.size(1), + c10::cuda::getCurrentCUDAStream())); + return; + case torch::kInt32: + CUDA_CALL(replicate::ReplicateForward(x.data_ptr(), + x.size(0), + x.size(1), + bins.data_ptr(), + out.data_ptr(), + out.size(1), + c10::cuda::getCurrentCUDAStream())); + return; + } + TORCH_CHECK(x.scalar_type() == torch::kInt16); + CUDA_CALL(replicate::ReplicateForward(x.data_ptr(), + x.size(0), + x.size(1), + bins.data_ptr(), + out.data_ptr(), + out.size(1), + c10::cuda::getCurrentCUDAStream())); +} + +void replicate_backward(torch::Tensor grad, + torch::Tensor bins, + torch::Tensor out) { + // Validate the inputs. + TORCH_CHECK(grad.is_cuda()); + TORCH_CHECK(grad.ndimension() == 2); + TORCH_CHECK(grad.scalar_type() == torch::kFloat16); + TORCH_CHECK(bins.is_cuda()); + TORCH_CHECK(bins.ndimension() == 1); + TORCH_CHECK(bins.scalar_type() == torch::kInt); + TORCH_CHECK(out.is_cuda()); + TORCH_CHECK(out.ndimension() == 2); + TORCH_CHECK(out.scalar_type() == torch::kFloat16); + + // Batch dimensions should match for input/output. + TORCH_CHECK(grad.size(0) == out.size(0)); + + // One output for each bin (in each batch). 
+ TORCH_CHECK(out.size(1) == bins.size(0)); + + replicate::cub_segmented_reduce(grad, bins, out, c10::cuda::getCurrentCUDAStream()); +} + +} // namespace megablocks + +#undef CUDA_CALL +#undef CUB_WRAPPED_NAMESPACE diff --git a/csrc/sort.h b/csrc/sort.h new file mode 100644 index 0000000000000000000000000000000000000000..251da182b3b488bcc87e8bb05c7400f14650ee61 --- /dev/null +++ b/csrc/sort.h @@ -0,0 +1,91 @@ +#undef CUB_WRAPPED_NAMESPACE +#define CUB_WRAPPED_NAMESPACE megablocks + +#include + +#include +#include +// #include + +#define CUDA_CALL(code) \ + do { \ + cudaError_t status = code; \ + std::string err = cudaGetErrorString(status); \ + TORCH_CHECK(status == cudaSuccess, err); \ + } while (0) + +namespace megablocks { + +template +void cub_radix_sort(torch::Tensor x, + int end_bit, + torch::Tensor x_out, + torch::Tensor iota_out) { + // Get iota for values in sort. + torch::Tensor iota = torch::arange(0, x.numel(), x.options()); + + // Get temporary buffer size. + size_t scratchpad_bytes = 0; + CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr, + scratchpad_bytes, + x.data_ptr(), + x_out.data_ptr(), + iota.data_ptr(), + iota_out.data_ptr(), + x.numel(), + /*begin_bit*/0, + /*end_bit=*/end_bit, + c10::cuda::getCurrentCUDAStream())); + + // Allocate scratchpad. + auto options = torch::TensorOptions() + .dtype(torch::kInt8) + .device(x.device()); + torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options); + + // Run the kernel. + CUDA_CALL(cub::DeviceRadixSort::SortPairs(scratchpad.data_ptr(), + scratchpad_bytes, + x.data_ptr(), + x_out.data_ptr(), + iota.data_ptr(), + iota_out.data_ptr(), + x.numel(), + /*begin_bit=*/0, + /*end_bit=*/end_bit, + c10::cuda::getCurrentCUDAStream())); +} + +void sort(torch::Tensor x, + int end_bit, + torch::Tensor x_out, + torch::Tensor iota_out) { + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.ndimension() == 1); + TORCH_CHECK(x.scalar_type() == torch::kInt16 || + x.scalar_type() == torch::kInt32 || + x.scalar_type() == torch::kInt64); + TORCH_CHECK(x_out.is_cuda()); + TORCH_CHECK(x_out.ndimension() == 1); + TORCH_CHECK(x_out.scalar_type() == x.scalar_type()); + TORCH_CHECK(iota_out.is_cuda()); + TORCH_CHECK(iota_out.ndimension() == 1); + TORCH_CHECK(iota_out.scalar_type() == x.scalar_type()); + + // Exit early if there is not work to do. 
+ if (x_out.numel() == 0) return; + + switch (x.scalar_type()) { + case torch::kInt16: + return cub_radix_sort(x, end_bit, x_out, iota_out); + case torch::kInt32: + return cub_radix_sort(x, end_bit, x_out, iota_out); + } + TORCH_CHECK(x.scalar_type() == torch::kInt64); + return cub_radix_sort(x, end_bit, x_out, iota_out); +} + +} // namespace megablocks + +#undef CUDA_CALL +#undef CUB_WRAPPED_NAMESPACE diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000000000000000000000000000000000000..956c21dcb9602669b8a55b867d7b10f1c29f28fd --- /dev/null +++ b/flake.lock @@ -0,0 +1,164 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_2": { + "locked": { + "lastModified": 1733328505, + "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "hf-nix": { + "inputs": { + "flake-compat": "flake-compat_2", + "flake-utils": "flake-utils_2", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1748598786, + "owner": "huggingface", + "repo": "hf-nix", + "rev": "6ca679441494139fde1f2355691ddb5dc8170269", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "hf-nix", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "hf-nix": "hf-nix", + "nixpkgs": [ + "kernel-builder", + "hf-nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1749576434, + "narHash": "sha256-wSdtZih2fMQ3ne/U7OKIhmP43zCIuRBhJ5zMMz747u0=", + "path": "/home/ubuntu/Projects/kernel-builder", + "type": "path" + }, + "original": { + "path": "/home/ubuntu/Projects/kernel-builder", + "type": "path" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1747820358, + "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=", + "owner": "danieldk", + "repo": "nixpkgs", + "rev": "d3c1681180717528068082103bf323147de6ab0b", + "type": "github" + }, + "original": { + "owner": "danieldk", + "ref": "cudatoolkit-12.9-kernel-builder", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": 
"nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000000000000000000000000000000000000..bd568e23d91202dd6b6af3655384b28ee94f0f15 --- /dev/null +++ b/flake.nix @@ -0,0 +1,18 @@ +{ + description = "Flake for megablocks_moe kernel"; + + inputs = { + kernel-builder.url = "path:/home/ubuntu/Projects/kernel-builder"; + # kernel-builder.url = "github:huggingface/kernel-builder/v0.4.0"; + }; + + outputs = + { + self, + kernel-builder, + }: + kernel-builder.lib.genFlakeOutputs { + path = ./.; + rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; + }; +} diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_mb_moe.py b/tests/test_mb_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..ff6f5b22370273e90e8acc8a2bc5b56196d96f06 --- /dev/null +++ b/tests/test_mb_moe.py @@ -0,0 +1,6 @@ +import megablocks + +def test_import(): + """Simple test to check if the module can be imported.""" + print("megablocks_moe module imported successfully.") + print("Available functions:", dir(megablocks)) diff --git a/torch-ext/megablocks/__init__.py b/torch-ext/megablocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b5dc7204fde87d8dc125bd8e9f0afc641820de52 --- /dev/null +++ b/torch-ext/megablocks/__init__.py @@ -0,0 +1,191 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from ._ops import ops + +from megablocks.layers.arguments import Arguments +from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE +from megablocks.layers.glu import SparseGLU +from megablocks.layers.mlp import MLP, SparseMLP +from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +# This section contains the direct kernel exports (not inlcuded in the original code) +def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute exclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.exclusive_cumsum(x, dim, out) + + +def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute inclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.inclusive_cumsum(x, dim, out) + + +def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: + """ + Compute histogram of input tensor values. 
+ + Args: + x: Input tensor + num_bins: Number of histogram bins + + Returns: + Histogram tensor with counts for each bin + """ + return ops.histogram(x, num_bins) + + +def indices( + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, +) -> torch.Tensor: + """ + Construct indices from padded bins for sparse operations. + + Args: + padded_bins: Tensor containing bin boundaries + block_size: Size of each block + output_block_rows: Number of rows in output blocks + output_block_columns: Number of columns in output blocks + + Returns: + Tensor containing constructed indices + """ + return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) + + +def replicate_forward( + x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Forward pass of replicate operation - replicate values according to bin sizes. + + Args: + x: Input tensor with values to replicate + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_forward(x, bins, out) + + +def replicate_backward( + grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Backward pass of replicate operation - reduce gradients back to bins. + + Args: + grad: Gradient tensor to reduce + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_backward(grad, bins, out) + + +def sort( + x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor +) -> torch.Tensor: + """ + Radix sort with index tracking. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + x_out: Output tensor for sorted values + iota_out: Output tensor for sorted indices + + Returns: + The sorted values tensor + """ + return ops.sort(x, end_bit, x_out, iota_out) + + +# Convenience functions for common use cases +def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: + """ + Compute cumulative sum with automatic output allocation. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum (default: last dimension) + exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum + + Returns: + New tensor containing the cumulative sum + """ + out = torch.empty_like(x) + if exclusive: + return exclusive_cumsum(x, dim, out) + else: + return inclusive_cumsum(x, dim, out) + + +def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sort tensor and return both sorted values and indices. 
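+    For example (illustrative), argsort of [3, 1, 2] returns the sorted values
+    [1, 2, 3] together with the original indices [1, 2, 0].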
+ + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + + Returns: + Tuple of (sorted_values, sorted_indices) + """ + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + sort(x, end_bit, x_out, iota_out) + return x_out, iota_out + + +# Export public API +__all__ = [ + # Direct kernel exports + "exclusive_cumsum", + "inclusive_cumsum", + "histogram", + "indices", + "replicate_forward", + "replicate_backward", + "sort", + "cumsum", + "argsort", + # Original exports + "Arguments", + "ParallelDroplessMLP", + "dMoE", + "SparseGLU", + "MLP", + "SparseMLP", + "MoE", + "ParallelMLP", + "get_load_balancing_loss", +] diff --git a/torch-ext/megablocks/_version.py b/torch-ext/megablocks/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279 --- /dev/null +++ b/torch-ext/megablocks/_version.py @@ -0,0 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +"""The MegaBlocks Version.""" + +__version__ = '0.11.0.dev0' diff --git a/torch-ext/megablocks/backend/__init__.py b/torch-ext/megablocks/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332 --- /dev/null +++ b/torch-ext/megablocks/backend/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 diff --git a/torch-ext/megablocks/backend/kernels.py b/torch-ext/megablocks/backend/kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..b584ceede926ca30abef2dec581cb3ff329e8e16 --- /dev/null +++ b/torch-ext/megablocks/backend/kernels.py @@ -0,0 +1,543 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') + + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') + + +def assert_equal(a, b): + if a != b: + raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) + + +# a: (tokens, hidden_size), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy( + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Our index into array 'a'. + index_a = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'b'. 
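+    # `padded_bins` holds the inclusive cumulative sum of the padded bin sizes, so
+    # padded_bins[bin_idx - 1] is the first row of this bin in 'b'; adding our offset
+    # within the bin gives this threadblock's output row.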
+ index_b = offset_in_bin + if bin_idx > 0: + index_b += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: Because of the padding, the output size is dynamic. + # We load the final padded bin bound to get the output rows. + output_rows = padded_bins[-1].cpu().item() + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def gather(x, indices, bin_ids, weights, bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: There is no padding so the output rows equals the + # input rows multiplied by top_k. + output_rows = x.shape[0] * top_k + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. 
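+    # NOTE: This is the inverse of padded_gather: rows are copied from the padded
+    # expert layout back to (tokens, top_k, hidden) and then reduced over top_k.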
+ assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + out, + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + + +def scatter(x, indices, bin_ids, weights, bins, top_k): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +# x: (tokens, top_k, hidden_size), real +# grad: (tokens, hidden_size), real. +# wgrad: (tokens, top_k), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy_wgrad( + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Our index into 'tokens * top_k'. + index_out = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'x'. + index_x = offset_in_bin + if bin_idx > 0: + index_x += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): + # Validate the input shapes. 
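+    # NOTE: The result is the gradient w.r.t. the routing weights: one scalar per
+    # routed token, a dot product of the permuted activations and the incoming grad.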
+ assert_is_matrix(x) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + _padded_copy_wgrad[(indices.shape[0],)]( + x, + grad, + out, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + TOP_K=top_k, + ) + return out + + +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy( + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_b = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_a = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + # + # NOTE: We need to zero the output in both directions. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): + # Validate the input shapes. 
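+    # NOTE: Unlike padded_gather, this writes into a fixed-capacity buffer of shape
+    # (num_experts, expert_capacity, hidden); tokens beyond an expert's capacity
+    # are simply not copied.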
+ assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + num_experts = bins.shape[0] + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + + _binned_copy[(num_experts, expert_capacity)]( + x, + out, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def binned_scatter(x, indices, weights, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( + out, + x, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=hidden_size, + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy_wgrad( + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_x = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_out = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def binned_scatter_wgrad(x, grad, indices, bins, top_k): + # Validate the input shapes. 
+ assert_is_tensor(x, 3) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) + _binned_copy_wgrad[(num_experts, expert_capacity)]( + x, + grad, + out, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS=hidden_size, + TOP_K=top_k, + ) + return out diff --git a/torch-ext/megablocks/bak.__init__.py b/torch-ext/megablocks/bak.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5217959caf74527e3bf7f80db6f93be21c016963 --- /dev/null +++ b/torch-ext/megablocks/bak.__init__.py @@ -0,0 +1,23 @@ +from megablocks_moe.megablocks import ( + MoE, + dMoE, + get_load_balancing_loss, + ParallelMLP, + ParallelDroplessMLP, + SparseMLP, + MLP, + SparseGLU, + Arguments, +) + +__all__ = [ + "MoE", + "dMoE", + "get_load_balancing_loss", + "ParallelMLP", + "ParallelDroplessMLP", + "SparseMLP", + "MLP", + "SparseGLU", + "Arguments", +] diff --git a/torch-ext/megablocks/benchmark_util.py b/torch-ext/megablocks/benchmark_util.py new file mode 100644 index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f --- /dev/null +++ b/torch-ext/megablocks/benchmark_util.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + + +def log_benchmark(name, arguments, time, std): + print('=' * 60) + print(f'{name} Benchmark') + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) + print('=' * 60) + + +def benchmark_function(fn, iterations=100, warmup=10): + # Warmup iterations. + for _ in range(warmup): + fn() + + times = [] + for i in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + fn() + end.record() + + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + return np.mean(times), np.std(times) diff --git a/torch-ext/megablocks/grouped_gemm_util.py b/torch-ext/megablocks/grouped_gemm_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6d3f977f360fd0ad5800c3b5da9ce57be794b9b8 --- /dev/null +++ b/torch-ext/megablocks/grouped_gemm_util.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import warnings + +_grouped_gemm_is_available: bool = False +try: + import grouped_gemm + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn('Grouped GEMM not available.') + + +def grouped_gemm_is_available(): + return _grouped_gemm_is_available + + +def assert_grouped_gemm_is_available(): + msg = ( + 'Grouped GEMM not available. 
Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg + + +backend = grouped_gemm.backend if grouped_gemm_is_available() else None +ops = grouped_gemm.ops if grouped_gemm_is_available() else None diff --git a/torch-ext/megablocks/layers/__init__.py b/torch-ext/megablocks/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..849b023b1a0f765fb1e26addf4670dc6db785a52 --- /dev/null +++ b/torch-ext/megablocks/layers/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# from megablocks.layers.dmoe import dMoE +from megablocks.layers.moe import MoE + +__all__ = [ + 'MoE', + # 'dMoE', +] diff --git a/torch-ext/megablocks/layers/activation_fn.py b/torch-ext/megablocks/layers/activation_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..a31770ba179fed06abf2da10102ccaeed1d3ee4e --- /dev/null +++ b/torch-ext/megablocks/layers/activation_fn.py @@ -0,0 +1,33 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Union + +import torch +from stk import Matrix + + +def act_fn( + x: Matrix, + function: Callable, + return_grad_fn: bool = False, + **kwargs, +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) + with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): + if return_grad_fn: + x.data.requires_grad = True + out = function(x.data, **kwargs) + y = Matrix( + x.size(), + out, + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + if return_grad_fn: + return y, out.backward + return y diff --git a/torch-ext/megablocks/layers/all_to_all.py b/torch-ext/megablocks/layers/all_to_all.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3 --- /dev/null +++ b/torch-ext/megablocks/layers/all_to_all.py @@ -0,0 +1,54 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + + +class AllToAllOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) + + ctx.input_shape = x.shape + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + ctx.group = group + handle = dist.all_to_all_single( + out, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + return out, handle + + @staticmethod + def backward(ctx, grad, _): + if ctx.needs_input_grad[0]: + out = torch.empty( + ctx.input_shape, + device=grad.device, + dtype=grad.dtype, + ) + dist.all_to_all_single( + out, + grad, + output_split_sizes=ctx.input_split_sizes, + input_split_sizes=ctx.output_split_sizes, + group=ctx.group, + ) + return out, None, None, None, None + return None, None, None, None, None + + +def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): + return AllToAllOp.apply( + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) diff --git a/torch-ext/megablocks/layers/arguments.py b/torch-ext/megablocks/layers/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..3962c771c90012535aab058443f89c541a2e9236 --- /dev/null +++ b/torch-ext/megablocks/layers/arguments.py @@ -0,0 +1,100 
@@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from functools import partial +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +import megablocks.grouped_gemm_util as grouped_gemm + +# Type annotation for in-place Tensor initialization function. +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] + +_ALLOWED_BITWIDTHS = (-1, 4, 8) + +DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') + + +@dataclasses.dataclass +class Arguments: + # Model arguments. + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + num_layers: int = 1 + bias: bool = True + return_bias: bool = True + activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN + + # MoE arguments. + moe_num_experts: int = 1 + moe_top_k: int = 1 + moe_capacity_factor: int = 1 + moe_normalize_expert_weights: Optional[Union[int, float]] = None + moe_loss_weight: float = 0.1 + moe_jitter_eps: Optional[float] = None + moe_lbl_in_fp32: bool = False + + # Parallelism arguments. + moe_expert_model_parallelism: bool = False + expert_parallel_group: Optional[dist.ProcessGroup] = None + pipeline_model_parallel_size: int = 1 + num_layers_per_virtual_pipeline_stage: Optional[int] = None + + # Compute arguments. + memory_optimized_mlp: bool = False + mlp_type: str = 'mlp' + mlp_impl: str = 'sparse' + + # Initialization arguments. + fp16: bool = True + bf16: bool = False + device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) + init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) + output_layer_init_method: InitFn = init_method + + # Benchmarking arguments. + uniform_expert_assignment: bool = False + + # shared expert arguments + shared_expert: bool = False # enable using shared expert + fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) + fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers + remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored + shared_expert_hidden_size: Optional[ + int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size + shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) + + # Router Z-loss arguments + moe_zloss_weight: float = 0 # 1e-3 is a reasonable value + moe_zloss_in_fp32: bool = False + + def __post_init__(self): + # Sparse MLP is not supported with triton >=3.2.0 + # TODO: Remove this once sparse is supported with triton >=3.2.0 + if self.__getattribute__('mlp_impl') == 'sparse': + try: + import triton + if triton.__version__ >= '3.2.0': + raise ValueError( + 'Sparse MLP is not supported with triton >=3.2.0. 
Please use mlp_impl="grouped" instead.', + ) + except ImportError: + raise ImportError('Triton is required for sparse MLP implementation') + + if self.__getattribute__('mlp_impl') == 'grouped': + grouped_gemm.assert_grouped_gemm_is_available() + + if self.shared_expert_hidden_size is None: + self.shared_expert_hidden_size = self.ffn_hidden_size + + +def from_megatron(megatron_args: Any): + args = Arguments() + for field in dataclasses.fields(args): + if hasattr(megatron_args, field.name): + setattr(args, field.name, getattr(megatron_args, field.name)) + return args diff --git a/torch-ext/megablocks/layers/common.py b/torch-ext/megablocks/layers/common.py new file mode 100644 index 0000000000000000000000000000000000000000..ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927 --- /dev/null +++ b/torch-ext/megablocks/layers/common.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from megablocks.layers.arguments import Arguments + + +def dtype(args: Arguments): + if args.fp16: + return torch.float16 + elif args.bf16: + return torch.bfloat16 + return None + + +def cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == 'cuda': + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == 'cpu': + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor diff --git a/torch-ext/megablocks/layers/dmlp_registry.py b/torch-ext/megablocks/layers/dmlp_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d765bd04387a29bddd2789fae04821635b555e82 --- /dev/null +++ b/torch-ext/megablocks/layers/dmlp_registry.py @@ -0,0 +1,42 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +MlpType = Union[mlp.SparseMLP, glu.SparseGLU] + +_REGISTRY = { + 'mlp': { + 'grouped': mlp.GroupedMLP, + 'sparse': mlp.SparseMLP, + }, + 'glu': { + 'grouped': glu.GroupedGLU, + 'sparse': glu.SparseGLU, + }, +} + + +def get(args: Arguments) -> MlpType: + """Returns an MLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + MLP instance. This only contains MLPs for use in dMoEs + (ie. only for the dropless versions of MoEs). + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated MLP constructed using the input args. 
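+
+    Example (illustrative; assumes a CUDA device and the default
+    mlp_type='mlp' / mlp_impl='sparse' settings):
+
+        expert_mlp = get(Arguments())  # -> mlp.SparseMLP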
+ """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + if args.mlp_impl not in _REGISTRY[args.mlp_type]: + raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) + + return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/torch-ext/megablocks/layers/dmoe.py b/torch-ext/megablocks/layers/dmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..205727ff4d63f9e8dc9648acaac99a97f3394d6f --- /dev/null +++ b/torch-ext/megablocks/layers/dmoe.py @@ -0,0 +1,327 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import stk.ops +import torch +from stk import Matrix + +import megablocks.ops as ops +# from megablocks.ops import ops +from megablocks.layers import common, dmlp_registry, moe, mpu +from megablocks.layers.arguments import Arguments + + +def promote_scalar(x): + return x.view(1) if not len(x.size()) else x + + +class ParallelDroplessMLP(moe.ParallelMLP): + + def __init__(self, args: Arguments): + super(ParallelDroplessMLP, self).__init__(args) + self.hidden_size = args.hidden_size + self.ffn_hidden_size = mpu.features_per_rank(args) + self.blocking = 128 + self.mlp = dmlp_registry.get(args) + + # Calculate the number of bits needed to represent the column indices + # in the intermediate sparse matrix. + max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) + self.transpose_sort_end_bit = max( + int(np.ceil(np.log2(max_column_index))), + 1, + ) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + # + # NOTE: Our sort operation uses the same width indices as the input values. + # To avoid overflow when we have large activation matrices we cast to + # 32-bit before sorting. + _, gather_indices = ops.sort( + column_indices.int(), + self.transpose_sort_end_bit, + ) + + # There are a constant number of blocks in every row of the sparse matrix. + # A blocks offset is: + # + # row_index * blocks_per_row + column_index % blocks_per_row + # + # Once we have the block offsets ordered for transposition we can divide + # by blocks_per_row to get the transposed column indices. + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + if nnz_per_column.dim() == 0: + # This addresses an edge case when ffn_hidden_size is equal to self.blocking. + nnz_per_column = nnz_per_column.unsqueeze(0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + if self.ffn_hidden_size % self.blocking != 0: + raise ValueError( + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + + f'the block size {self.blocking}. Please update your configuration.', + ) + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. 
+ block_rows = padded_tokens // self.blocking + blocks_per_row = self.ffn_hidden_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology( + padded_bins, + self.blocking, + block_rows, + blocks_per_row, + ) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=common.dtype(self.args), + device='meta', + ) + shape = ( + padded_tokens, + self.ffn_hidden_size * mpu.experts_per_rank(self.args), + ) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, + row_indices, + column_indices, + offsets, + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_experts = top_experts.int() + bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Caculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = promote_scalar(bins) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def sparse_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather( + x, + indices, + bin_ids, + bins, + padded_bins, + self.top_k, + ) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.top_k, + ) + return x, tokens_per_expert + + # For use in the base-class parallel_forward_once. + def sparse_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. 
+ padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + return ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + top_k, + ) + + def grouped_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + out = self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + -1, # unused + self.args.moe_top_k, + ) + return out, tokens_per_expert + + def grouped_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Perform the expert computation. + x = self.mlp(x, tokens_per_expert) + + # Un-route the data for the MoE output. + return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + def forward_once(self, x, expert_weights, top_experts): + if self.args.mlp_impl == 'sparse': + return self.sparse_forward_once(x, expert_weights, top_experts) + else: + return self.grouped_forward_once(x, expert_weights, top_experts) + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ): + if self.args.mlp_impl == 'sparse': + return self.sparse_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + else: + return self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + + +class dMoE(moe.MoE): + + def _init_experts_mlp(self, args: Arguments): + return ParallelDroplessMLP(args) diff --git a/torch-ext/megablocks/layers/gelu.py b/torch-ext/megablocks/layers/gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..40b601d4a8eb59e80e1090f31f4172b8f7fb7549 --- /dev/null +++ b/torch-ext/megablocks/layers/gelu.py @@ -0,0 +1,43 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk +import torch +import torch.nn.functional as F + + +@torch.jit.script +def _gelu_backward_inplace(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) + return g.mul_(ff) + + +def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): + # NOTE: The two sparse matrices must have the same topology. 
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): + return stk.Matrix( + x.size(), + _gelu_backward_inplace(grad.data, x.data), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + return _gelu_backward_inplace(grad, x) + + +def gelu(x: stk.Matrix): + assert isinstance(x, stk.Matrix) + return stk.Matrix( + x.size(), + F.gelu(x.data, approximate='tanh'), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) diff --git a/torch-ext/megablocks/layers/glu.py b/torch-ext/megablocks/layers/glu.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe0c915c307e7f7cade3ea3ff679399635fcd81 --- /dev/null +++ b/torch-ext/megablocks/layers/glu.py @@ -0,0 +1,223 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk.ops +import torch + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import Arguments +from megablocks.layers.mlp import ( + SharedMLP, + SparseMLP, + create_dmoe_expert_weights, + resolve_dtensor, +) + + +class SparseGLU(SparseMLP): + + def __init__(self, args: Arguments): + super().__init__(args) + self.v1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + with torch.no_grad(): + self.v1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + + mpu.set_expert_model_parallel_attributes( + self.v1, + self._should_set_parallelism_attribute, + ) + + def forward(self, x, topo): + if self.args.memory_optimized_mlp: + raise NotImplementedError( + 'Memory optimized implementation not yet supported with GLU with sparse kernels.', + ) + + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Compute the GLU. + x1 = stk.ops.sdd(x, w1.t(), topo) + x2 = stk.ops.sdd(x, v1.t(), topo) + + activation_fn_out = act_fn(x1, self.args.activation_fn) + x1 = stk.ops.mul(activation_fn_out, x2) + + return stk.ops.dsd(x1, w2) + + +class MemoryOptimizedGroupedGLU(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + v1 = v1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) + + # GeLU. + activation_fn_out = activation_fn(sdd_out) * v1_out + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. 
+ ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, v1, w2 = saved_tensors[:3] + batch_sizes = saved_tensors[3] + x = saved_tensors[4] + sdd_out, v1_out = saved_tensors[5:7] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + v1_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * v1_out + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + dv1_out = v1_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dv1. + dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + dx = ddsd_out + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) + dx += gg.backend.gmm(dv1_out, v1, batch_sizes) + return dx, dw1, dv1, dw2, None, None + + +memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply + + +class GroupedGLU(SparseGLU): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, v1, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.v1), + self.scale_grad(self.w2), + ) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Re-shape the weights for the grouped GEMMs. + ne = mpu.experts_per_rank(self.args) + w1 = w1.view(ne, -1, self.args.hidden_size) + v1 = v1.view(ne, -1, self.args.hidden_size) + w2 = w2.view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_glu( + x, + w1, + v1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) + x1 = self.args.activation_fn(x1) * x2 + return gg.ops.gmm(x1, w2, batch_sizes) + + +class SharedGLU(SharedMLP): + """GPU for shared expert. 
+ + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class + """ + + def __init__(self, args: Arguments): + super().__init__(args) + self.gate_proj = args.fc_cls( + args.hidden_size, + self.args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/torch-ext/megablocks/layers/memory_test.py b/torch-ext/megablocks/layers/memory_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4acbd94f212ce906ace9de78dd3ffc6afa03f97e --- /dev/null +++ b/torch-ext/megablocks/layers/memory_test.py @@ -0,0 +1,102 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import gc + +import torch +import torch.distributed as dist + +from megablocks.layers import arguments, dmoe + +_TESTS = ((8, 2048, 4096, 4096, 32, 4),) + + +def get_tensors(): + ptrs = set() + out = [] + for obj in gc.get_objects(): + if torch.is_tensor(obj): + if not obj.is_contiguous() or obj.data_ptr() in ptrs: + continue + out.append(obj) + ptrs.add(obj.data_ptr()) + return out + + +def test_memory( + group, + batch_size, + sequence_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, +): + args = arguments.Arguments( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=num_experts, + moe_top_k=top_k, + moe_expert_model_parallelism=True, + expert_parallel_group=group, + fp16=False, + bf16=True, + device=torch.cuda.current_device(), + ) + layer = dmoe.dMoE(args).cuda() + + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16).requires_grad_(True) + torch.cuda.empty_cache() + + # Run forward + backward. + # with torch.autograd.detect_anomaly(): + out, _ = layer(x) + out.mean().backward() + + # Report peak memory. + mem = torch.cuda.max_memory_allocated() + print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) + print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) + + # Calculate weight and gradient memory usage. + weight_memory = 2 * ( + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() + ) + + def grad_numel(x): + if x.grad is not None: + return x.grad.numel() + return 0 + + grad_memory = 2 * ( + grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + ) + weight_memory += grad_memory + + print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) + print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) + + # Manually calculate GPU memory usage from the garbage + # collector. 
+ gc.collect() + total = 0 + tensors = get_tensors() + tensors = sorted(tensors, key=lambda x: -x.numel()) + for i, t in enumerate(tensors): + total += t.numel() + print(f'{i}: {t.shape}, {t.numel() * 2}') + del tensors + + print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _TESTS: + test_memory(group, *args) diff --git a/torch-ext/megablocks/layers/memory_test.sh b/torch-ext/megablocks/layers/memory_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..acf5704654439b61c6987859e7c3d52a60203fb4 --- /dev/null +++ b/torch-ext/megablocks/layers/memory_test.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +DISTRIBUTED_ARGUMENTS="\ +--nproc_per_node 1 \ +--nnodes 1 \ +--node_rank 0 \ +--master_addr localhost \ +--master_port 6000" + +python -m torch.distributed.launch \ + ${DISTRIBUTED_ARGUMENTS} \ + megablocks/layers/memory_test.py diff --git a/torch-ext/megablocks/layers/mlp.py b/torch-ext/megablocks/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6f4d82441e3a9c185db5bfdf686d53790dde26 --- /dev/null +++ b/torch-ext/megablocks/layers/mlp.py @@ -0,0 +1,574 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import stk +import stk.backend.triton_kernels +import stk.ops +import torch +from packaging import version + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, gelu, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + + +class ScaleGradient(torch.autograd.Function): + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx: Any, x: torch.Tensor, scale: float): + ctx.scale = scale + return x + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: torch.Tensor, grad: torch.Tensor): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +def resolve_dtensor(weight: torch.Tensor): + if version.parse(torch.__version__) >= version.parse('2.0.0'): + from torch.distributed._tensor import DTensor + if isinstance(weight, DTensor): + return weight.to_local() + return weight + + +def create_moe_expert_weights( + args: Arguments, + num_experts: int, + ffn_hidden_size: int, + hidden_size: int, + init_method: InitFn, +): + # Create the entire weight matrix such that the sampled weights will + # not vary between data parallelism and expert model parallelism for + # the same random seed. + master_weights = torch.empty( + num_experts, + ffn_hidden_size, + hidden_size, + device=args.device, + dtype=common.dtype(args), + ) + init_method(master_weights) + + if not args.moe_expert_model_parallelism: + return master_weights + + # Calculate the amount of sharding in each dimension. + expert_sharding_degree = mpu.expert_sharding_degree(args) + hidden_sharding_degree = mpu.hidden_sharding_degree(args) + + # Calculate the experts per rank. + # + # NOTE: We assign ranks to be expert parallel before going + # tensor parallel. 
+ rank = mpu.get_expert_parallel_rank(args) + expert_rank = rank % expert_sharding_degree + num_experts_per_rank = num_experts // expert_sharding_degree + start_expert = expert_rank * num_experts_per_rank + end_expert = (expert_rank + 1) * num_experts_per_rank + + # Calculate the rows per rank. + row_rank = rank // expert_sharding_degree + num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree + start_row = row_rank * num_rows_per_rank + end_row = (row_rank + 1) * num_rows_per_rank + + # Slice the weight matrix to get the chunk for this rank. + with torch.no_grad(): + weights = master_weights[start_expert:end_expert, start_row:end_row] + return weights + + +class MLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) + experts_per_rank = mpu.experts_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + args.hidden_size, + mpu.features_per_rank(args), + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + mpu.features_per_rank(args), + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + mpu.set_expert_model_parallel_attributes( + self.w1, + args.moe_expert_model_parallelism, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + args.moe_expert_model_parallelism, + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. 
+ with torch.no_grad(): + w1 = create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) + self.w1.copy_(w1.transpose(1, 2).contiguous()) + self.w2.copy_( + create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + x = torch.bmm(x, w1) + x = self.args.activation_fn(x) + return torch.bmm(x, w2) + + +def create_dmoe_expert_weights( + args: Arguments, + num_experts: int, + rows: int, + columns: int, + init_method: InitFn, +): + weights = create_moe_expert_weights( + args, + num_experts, + rows, + columns, + init_method, + ) + return weights.view([-1, columns]) + + +class MemoryOptimizedMLP(torch.autograd.Function): + """Sparse MLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, topo, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + topo_tensors = ( + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) + + # Layer 0: x @ w1.t(). + sdd_out = stk.ops.sdd(x, w1.t(), topo) + + # GeLU. + activation_fn_out = act_fn(sdd_out, activation_fn) + + # Layer 1: x @ w2. + dsd_out = stk.ops.dsd(activation_fn_out, w2) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.shape = topo.shape + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.data.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + topo_tensors = saved_tensors[2:8] + x = saved_tensors[8] + sdd_out_data = saved_tensors[9] + + # rematerialize activation function output + activation_fn = ctx.activation_fn + sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) + activation_fn_out, activation_grad_fn = act_fn( + sdd_out, + activation_fn, + return_grad_fn=True, + ) + + # Compute dw2 with recomputed activation_fn output. + dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. 
+ dactivation_fn_out = activation_fn_out + stk.backend.triton_kernels.sdd( + ddsd_out, + w2.t(), + dactivation_fn_out.shape, + dactivation_fn_out.data, + dactivation_fn_out.offsets, + dactivation_fn_out.row_indices, + dactivation_fn_out.column_indices, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out.data) + dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) + + # Compute dw1. + dw1 = stk.ops.dsd(dsdd_out.t(), x) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + stk.backend.triton_kernels.dsd( + dsdd_out.shape, + dsdd_out.data, + dsdd_out.offsets, + dsdd_out.row_indices, + dsdd_out.column_indices, + dsdd_out.offsets_t, + dsdd_out.column_indices_t, + dsdd_out.block_offsets_t, + False, + w1, + ddsd_out, + ) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_mlp = MemoryOptimizedMLP.apply + + +class SparseMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. + with torch.no_grad(): + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism + mpu.set_expert_model_parallel_attributes( + self.w1, + self._should_set_parallelism_attribute, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + self._should_set_parallelism_attribute, + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x, topo): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + if self.args.memory_optimized_mlp: + return memory_optimized_mlp( + x, + w1, + w2, + topo, + self.args.activation_fn, + ) + + # Compute the MLP. 
+ x = stk.ops.sdd(x, w1.t(), topo) + activation_fn_out = act_fn(x, self.args.activation_fn) + return stk.ops.dsd(activation_fn_out, w2) + + +class MemoryOptimizedGroupedMLP(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + + # activation_fn + activation_fn_out = activation_fn(sdd_out) + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: Any, ddsd_out: torch.Tensor): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + batch_sizes = saved_tensors[2] + x = saved_tensors[3] + sdd_out = saved_tensors[4] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply + + +class GroupedMLP(SparseMLP): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) + + # Re-shape the weights for the grouped GEMMs. 
+ ne = mpu.experts_per_rank(self.args) + w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) + w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_mlp( + x, + w1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x = self.args.activation_fn(x) + return gg.ops.gmm(x, w2, batch_sizes) + + +class SharedMLP(torch.nn.Module): + """MLP for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class + """ + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self.fc_kwargs: dict[str, Any] = { + 'bias': args.bias, + 'device': args.device, + } + self.fc_kwargs.update(args.fc_kwargs) + + self.up_proj = args.fc_cls( + args.hidden_size, + args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + self.act = args.activation_fn + self.down_proj = args.fc_cls( + args.shared_expert_hidden_size, + args.hidden_size, + **self.fc_kwargs, + ) + self.down_proj._is_residual = True # a flag for llm-foundry init + + def add_experts_sharedexpert( + self, + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + ) -> torch.Tensor: + # Helper function to add expert output to shared expert output + # with optional weighted sum. + if self.args.shared_expert_weighted_sum: + # enable using weighted sum for shared expert output + # wieghted by number of experts used + t_experts = self.args.moe_top_k + 1 + sh_mlp_out = shared_expert_out / t_experts + return sh_mlp_out.add( + expert_out, + alpha=(self.args.moe_top_k / t_experts), + ) + + return shared_expert_out + expert_out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.up_proj(x))) diff --git a/torch-ext/megablocks/layers/moe.py b/torch-ext/megablocks/layers/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e --- /dev/null +++ b/torch-ext/megablocks/layers/moe.py @@ -0,0 +1,475 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist + +import megablocks.ops as ops +from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +from megablocks.layers.all_to_all import all_to_all +from megablocks.layers.arguments import Arguments + +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss(args: Arguments): + if args.moe_loss_weight == 0: + return 0.0 + + # tokens_per_expert[i].shape = (num_experts) + # expert_scores[i].shape = (tokens, num_experts) + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} token_per_experts ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + 
f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} expert_scores ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) + + tokens = expert_scores[0].shape[0] + assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = (args.moe_num_experts * args.moe_loss_weight) + scale_denominator = (args.num_layers * tokens * args.moe_top_k) + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +# NOTE: This class defines MoE expert computation, including expert model parallel +# communication. When using FSDP on top of MegaBlocks this is the module that should +# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model +# parallel all2all. +class ParallelMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super(ParallelMLP, self).__init__() + self.args = args + + # Calculate the number of experts in total and the number of experts + # owned by this rank. + # world_size = mpu.get_expert_parallel_world_size(args) + self.num_experts = args.moe_num_experts + self.top_k = self.args.moe_top_k + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + # Expert MLP. + self.mlp = mlp.MLP(args) + + self.bias: Optional[torch.Tensor] + if self.args.bias: + # Note that the output bias is not parallelized with expert + # model parallelism. + self.bias = torch.nn.Parameter( + torch.empty( + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + # Select the forward function for the operating mode. 
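Stepping back from the module construction for a moment: the `save/get/clear_load_balancing_loss` helpers above accumulate one `(tokens_per_expert, scores)` pair per MoE layer per step, and `batched_load_balancing_loss` is meant to be drained once per training step. A hedged usage sketch, where `model`, `batch`, and the surrounding loop are placeholders rather than part of this package:

```python
from megablocks.layers import moe

def training_step(model, batch, args):
    # Forward pass; each MoE layer records its (tokens_per_expert, scores) pair.
    output = model(batch)
    task_loss = output.float().mean()          # placeholder objective

    # Add the auxiliary loss and reset the per-step buffer.
    aux_loss = moe.batched_load_balancing_loss(args)
    moe.clear_load_balancing_loss()
    return task_loss + aux_loss
```

Note that the pairs are only recorded when the module is in training mode and `moe_loss_weight > 0`, as shown in `ParallelMLP.forward` below.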
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) + + def expert_capacity(self, tokens: int) -> int: + world_size = mpu.get_expert_parallel_world_size(self.args) + tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) + return int(self.args.moe_capacity_factor * tokens_per_expert) + + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): + """Calculate the load balancing loss contribution.""" + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == self.num_experts + assert len(tokens_per_expert.size()) == 1 + num_experts, = tokens_per_expert.size() + assert num_experts == self.num_experts + scale = self.num_experts / (tokens * self.top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # + # TODO(tgale): Is it worth doing this conversion to 32-bit + # prior? Could we place the `torch.max` operation to return + # 32-bit expert indices? + top_expert = top_expert.int() + output = ops.sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + # + # TODO(tgale): Does the sorted data produce a more favorable + # data distribution for histogram? Or is the op parallelism + # worth more? + tokens_per_expert = ops.histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None + bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + + return indices, bin_ids, bins, tokens_per_expert + + def permute_and_compute( + self, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, + ): + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output + + # Perform the expert computation. Note that we don't + # use biases for these linear operations. + x = self.mlp(x) + + # Un-route the data for the MoE output. + return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. 
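The routing metadata built by `indices_and_bins` above comes from three custom kernels (`sort`, `histogram`, `inclusive_cumsum`), but the contract has a straightforward dense equivalent. A reference sketch (illustrative; `torch.sort` is not guaranteed stable, unlike the radix sort kernel, though the grouping it produces is equivalent):

```python
import torch

def indices_and_bins_reference(top_expert: torch.Tensor, num_experts: int):
    # Group token positions by their expert assignment.
    bin_ids, indices = torch.sort(top_expert.int())
    # Number of tokens routed to each expert.
    tokens_per_expert = torch.bincount(top_expert.long(), minlength=num_experts).int()
    # Inclusive cumulative sum gives the end offset ("bin boundary") per expert.
    bins = torch.cumsum(tokens_per_expert, dim=0).int()
    return indices, bin_ids, bins, tokens_per_expert
```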
+ sl, bs, _ = x.size() + expert_capacity = self.expert_capacity(sl * bs) + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = self.permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + self.top_k, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignement. After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + # + # Compute the mapping of local tokens to experts. + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so every device gets the counts. + repeated_tokens_per_expert = ops.repeat( + tokens_per_expert, + (mpu.hidden_sharding_degree(self.args),), + ) + + # Pass token count information to the device on which the + # target expert resides. + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.args.expert_parallel_group, + async_op=True, + ) + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + # + # This view updates the shape of the tensor from [sl, bs, hs] to + # [sl * bs, hs] prior to the permutation. + x = x.view(-1, x.shape[-1]) + output = ops.gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + experts_per_rank = mpu.experts_per_rank(self.args) + + # Reshape to [world_size, num_experts_per_rank]. + world_size = mpu.get_expert_parallel_world_size(self.args) + repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) + + # TODO(tgale): It might be faster to do this on the GPU and + # then communicate the results back to the host. + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. 
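To make the count exchange above concrete, here is a toy example with two ranks, two local experts per rank, and `hidden_sharding_degree == 1` (all numbers are made up):

```python
import torch

# Rows index the destination rank, columns its local experts. These are the
# counts this rank wants to send.
repeated_tokens_per_expert = torch.tensor([[3, 1],    # -> rank 0's experts
                                           [2, 4]])   # -> rank 1's experts
send_counts = repeated_tokens_per_expert.sum(dim=-1)           # tensor([4, 6])

# After the all_to_all of counts, row i holds what arrives *from* rank i for
# our local experts (values here are arbitrary).
parallel_tokens_per_expert = torch.tensor([[3, 1],
                                           [5, 2]])
recv_counts = parallel_tokens_per_expert.sum(dim=-1)           # tensor([4, 7])
tokens_received = int(recv_counts.sum())                       # 11
```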
+ send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so devices that share experts + # get all of the tokens assigned to them. + # + # TODO(tgale): Fuse this into the prior, local permutation. + x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, + recv_counts, + send_counts, + self.args.expert_parallel_group, + async_op=True, + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + replicate_bins = ops.inclusive_cumsum( + parallel_tokens_per_expert.flatten(), + 0, + ) + replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * mpu.hidden_sharding_degree(self.args), + dtype=torch.int32, + device=indices.device, + ), + mpu.experts_per_rank(self.args), + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, + tokens_received, + ).flatten() + + # TODO(tgale): The sort_end_bit here can be reduced. + parallel_bin_ids, parallel_indices = ops.sort( + parallel_top_expert, + self.sort_end_bit, + ) + + # Calculate the bins boundaries from the token counts. + parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, + dtype=torch.int, + ) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + tokens, _ = x.size() + expert_capacity = self.expert_capacity(tokens) + if expert_capacity == 0: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + if self.args.mlp_impl == 'grouped': + # GroupedMLP requires counts on CPU. We can use the tensor already + # moved to CPU for the prior all_to_all, which avoids an extra + # device synchronization. + parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( + dim=0, + dtype=torch.int, + ) + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all( + parallel_x, + send_counts, + recv_counts, + self.args.expert_parallel_group, + ) + + # Reduce along the hidden sharding to get the final outputs. + # + # TODO(tgale): Fuse this into the following local permutation. 
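The `parallel_top_expert` construction earlier in this method is compact; a toy example may help. With 4 experts in total, 2 ranks, `hidden_sharding_degree == 1`, and therefore 2 experts per rank (sizes made up):

```python
import torch

num_experts, hidden_sharding_degree, experts_per_rank = 4, 1, 2
parallel_top_expert = torch.remainder(
    torch.arange(num_experts * hidden_sharding_degree, dtype=torch.int32),
    experts_per_rank,
)
# tensor([0, 1, 0, 1], dtype=torch.int32): each incoming chunk (one per source
# rank) is labeled with the local expert ids 0..experts_per_rank-1, and
# ops.replicate then repeats each label once per token in its bin before the
# second local sort groups the received tokens by local expert.
```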
+ shape = ( + mpu.hidden_sharding_degree(self.args), + -1, + self.args.hidden_size, + ) + x = ops.sum(x.view(shape), dim=0) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + return x, tokens_per_expert.flatten() + + def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + in_shape = x.size() + + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) + if self.training and self.args.moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, scores)) + x = x.view(in_shape) + if self.bias is not None: + if self.args.return_bias: + return x, self.bias + return x + self.bias + return x + + +class MoE(torch.nn.Module): + + def __init__(self, args: Arguments): + super(MoE, self).__init__() + + # Token router. + self.router = router.LearnedRouter(args) + + # Expert computation helper. + self.experts = self._init_experts_mlp(args) + + self.shared_expert = None + if args.shared_expert: + # SharedExpert computation helper. + self.shared_expert = sharedexpert_registry.get(args) + + def _init_experts_mlp(self, args: Arguments): + return ParallelMLP(args) + + def forward(self, x: torch.Tensor): + # NOTE: If we're going to cast the activations to lower precision + # do it before we permute the tokens to save bandwidth. + x = common.cast_if_autocast_enabled(x) + + # Compute the expert scores and assignments. + scores, expert_weights, top_experts = self.router(x) + + # Compute the experts. + out = self.experts(x, scores, expert_weights, top_experts) + if self.shared_expert is not None: + shared_expert_out = self.shared_expert(x) + out = self.shared_expert.add_experts_sharedexpert( + shared_expert_out, + out, + ) + return out diff --git a/torch-ext/megablocks/layers/mpu.py b/torch-ext/megablocks/layers/mpu.py new file mode 100644 index 0000000000000000000000000000000000000000..b23213902f4567d7fdb0158cbcf5406c2b2aa601 --- /dev/null +++ b/torch-ext/megablocks/layers/mpu.py @@ -0,0 +1,93 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.distributed as dist + +from megablocks.layers.arguments import Arguments + + +class MoeParam(torch.Tensor): + + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool + + +def is_moe_param(tensor: torch.Tensor) -> bool: + return hasattr(tensor, 'expert_model_parallel') + + +def get_expert_parallel_world_size(args: Arguments) -> int: + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) + + +def get_expert_parallel_rank(args: Arguments) -> int: + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) + + +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, 'expert_model_parallel') + setattr(tensor, 'expert_model_parallel', is_parallel) + + +def param_is_expert_model_parallel(param: MoeParam) -> bool: + return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) + + +def copy_expert_model_parallel_attributes( + destination_tensor: torch.Tensor, + source_tensor: torch.Tensor, +): + if hasattr(source_tensor, 'expert_model_parallel'): + setattr( + destination_tensor, + 'expert_model_parallel', + getattr(source_tensor, 'expert_model_parallel'), + ) + + +def synchronized_print(group: 
Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + for i in range(world_size): + dist.barrier(group) + if i == rank: + print(f'rank = {rank}', *x) + + +# Helpers for expert/tensor sharding. +def expert_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = min(world_size, args.moe_num_experts) + + if (args.moe_num_experts % esd) != 0: + raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) + return esd + + +def hidden_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = expert_sharding_degree(args) + hsd = world_size // esd + + if (args.ffn_hidden_size % hsd) != 0: + raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", + ) + return hsd + + +def experts_per_rank(args: Arguments) -> int: + return args.moe_num_experts // expert_sharding_degree(args) + + +def features_per_rank(args: Arguments) -> int: + return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/torch-ext/megablocks/layers/router.py b/torch-ext/megablocks/layers/router.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9dcd9e433322482a80d5b95afccee5c12368f8 --- /dev/null +++ b/torch-ext/megablocks/layers/router.py @@ -0,0 +1,114 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch + +from megablocks.layers import common +from megablocks.layers.arguments import Arguments + +_ROUTER_LOGITS = [] + + +def _save_router_logits(logits: torch.Tensor, args: Arguments): + if args.moe_zloss_weight == 0: + return + global _ROUTER_LOGITS + _ROUTER_LOGITS.append(logits) + + +def clear_router_zloss(): + global _ROUTER_LOGITS + _ROUTER_LOGITS.clear() + + +def batched_router_zloss(args: Arguments): + global _ROUTER_LOGITS + + if args.moe_zloss_weight == 0: + import warnings + warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') + return 0 + + logits_per_router = _ROUTER_LOGITS + + if args.moe_zloss_in_fp32: + logits_per_router = [logits.float() for logits in logits_per_router] + + unscaled_zloss_per_router = torch.stack([ + torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router + ]) + + return args.moe_zloss_weight * unscaled_zloss_per_router + + +# NOTE: To enable end-to-end benchmarking without convergence we +# support a flag to force the router to assign tokens uniformly +# across the experts. We do this with a custom autograd operation +# so that PyTorch still executes the full set of router operation. +class _UniformExpertAssignment(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, num_experts: int): + out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) + out = torch.remainder(out, num_experts) + return out.view(x.shape) + + +_uniform_expert_assignment = _UniformExpertAssignment.apply + + +class LearnedRouter(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + + # Learned router parameters. + # + # NOTE: This weight matrix is not parallelized with expert model + # parallelism. Each device needs the entire router weight matrix + # so that it can route its batch of data correctly. 
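Before continuing with the router below, the mpu sharding helpers above are easiest to see with numbers. A worked example (all sizes made up): with 8 expert-parallel ranks, 4 experts, and an FFN hidden size of 8192,

```python
world_size, moe_num_experts, ffn_hidden_size = 8, 4, 8192

expert_sharding_degree = min(world_size, moe_num_experts)          # 4
hidden_sharding_degree = world_size // expert_sharding_degree      # 2
experts_per_rank = moe_num_experts // expert_sharding_degree       # 1
features_per_rank = ffn_hidden_size // hidden_sharding_degree      # 4096
assert expert_sharding_degree * hidden_sharding_degree == world_size
```

Each expert is split across 2 ranks along its hidden dimension, which is exactly the case that the replication and `ops.sum` reduction in `parallel_forward_once` exist to handle.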
+ self.layer = torch.nn.Linear( + args.hidden_size, + args.moe_num_experts, + bias=False, + dtype=common.dtype(args), + device=args.device, + ) + args.init_method(self.layer.weight) + + def jitter(self, x: torch.Tensor): + low: float = 1.0 - self.args.moe_jitter_eps + high: float = 1.0 + self.args.moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return low + noise * (high - low) + + def _top_k(self, scores: torch.Tensor): + if self.args.moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, self.args.moe_top_k, dim=-1) + + def forward(self, x: torch.Tensor): + if self.training and self.args.moe_jitter_eps is not None: + x = x * self.jitter(x) + + logits = self.layer(x.view(-1, x.shape[-1])) + _save_router_logits(logits, self.args) + scores = logits.softmax(dim=-1) + expert_weights, expert_indices = self._top_k(scores) + if self.args.moe_normalize_expert_weights: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + + expert_indices = ( + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) if self.args.uniform_expert_assignment else expert_indices + ) + return scores, expert_weights, expert_indices diff --git a/torch-ext/megablocks/layers/sharedexpert_registry.py b/torch-ext/megablocks/layers/sharedexpert_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..0f62db39a2dda5642f6baa9d9b949f7c31cf6d35 --- /dev/null +++ b/torch-ext/megablocks/layers/sharedexpert_registry.py @@ -0,0 +1,30 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +_REGISTRY = { + 'mlp': mlp.SharedMLP, + 'glu': glu.SharedGLU, +} + + +def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: + """Returns an SharedMLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + SharedMLP instance. + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated SharedMLP constructed using the input args. 
+ """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + return _REGISTRY[args.mlp_type](args) diff --git a/torch-ext/megablocks/ops/__init__.py b/torch-ext/megablocks/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c --- /dev/null +++ b/torch-ext/megablocks/ops/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from megablocks.ops.binned_gather import binned_gather +from megablocks.ops.binned_scatter import binned_scatter +from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum +from megablocks.ops.gather import gather +from megablocks.ops.histogram import histogram +from megablocks.ops.padded_gather import padded_gather +from megablocks.ops.padded_scatter import padded_scatter +from megablocks.ops.repeat import repeat +from megablocks.ops.replicate import replicate +from megablocks.ops.round_up import round_up +from megablocks.ops.scatter import scatter +from megablocks.ops.sort import sort +from megablocks.ops.sum import sum +from megablocks.ops.topology import topology + +__all__ = [ + 'binned_gather', + 'binned_scatter', + 'exclusive_cumsum', + 'inclusive_cumsum', + 'gather', + 'histogram', + 'padded_gather', + 'padded_scatter', + 'repeat', + 'replicate', + 'round_up', + 'scatter', + 'sort', + 'sum', + 'topology', +] diff --git a/torch-ext/megablocks/ops/all_to_all_benchmark.py b/torch-ext/megablocks/ops/all_to_all_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..47b95301dbb35b0c3154d1ff2878d20792ce5cb2 --- /dev/null +++ b/torch-ext/megablocks/ops/all_to_all_benchmark.py @@ -0,0 +1,60 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + +from megablocks import benchmark_util +from megablocks.layers.all_to_all import all_to_all + +_ALL_TO_ALL_BENCHMARK = ( + (8, 1024), + (16, 1024), + (32, 1024), + (64, 1024), + (128, 1024), + (256, 1024), + (512, 1024), + (1024, 1024), + (2 * 1024, 1024), + (4 * 1024, 1024), + (8 * 1024, 1024), + (16 * 1024, 1024), + (32 * 1024, 1024), + (64 * 1024, 1024), + (128 * 1024, 1024), + (256 * 1024, 1024), + (512 * 1024, 1024), + (1024 * 1024, 1024), +) + + +def benchmark_all_to_all(group, sl, hs): + world_size = dist.get_world_size(group) + assert (sl % world_size) == 0 + send_recv_sizes = [sl // world_size] * world_size + + x = torch.randn((sl, hs)).cuda().half() + + details = { + 'world_size': world_size, + 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. 
+ } + + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) + + time, std = benchmark_util.benchmark_function(benchmark) + + if dist.get_rank(group) == 0: + benchmark_util.log_benchmark('All-To-All', details, time, std) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _ALL_TO_ALL_BENCHMARK: + benchmark_all_to_all(group, *args) diff --git a/torch-ext/megablocks/ops/all_to_all_benchmark.sh b/torch-ext/megablocks/ops/all_to_all_benchmark.sh new file mode 100755 index 0000000000000000000000000000000000000000..b4ff4659c7c698c8672a631c297d434b136becb9 --- /dev/null +++ b/torch-ext/megablocks/ops/all_to_all_benchmark.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +DISTRIBUTED_ARGUMENTS="\ +--nproc_per_node 8 \ +--nnodes 1 \ +--node_rank 0 \ +--master_addr localhost \ +--master_port 6000" + +python -m torch.distributed.launch \ + ${DISTRIBUTED_ARGUMENTS} \ + megablocks/ops/all_to_all_benchmark.py diff --git a/torch-ext/megablocks/ops/binned_gather.py b/torch-ext/megablocks/ops/binned_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..89cce1b627137d7f987baca62fd6d82c5c04659a --- /dev/null +++ b/torch-ext/megablocks/ops/binned_gather.py @@ -0,0 +1,37 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_gather kernel. +class BinnedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): + ctx.save_for_backward(indices, bins) + ctx.top_k = top_k + return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + indices, bins = ctx.saved_tensors + out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) + return out, None, None, None, None + + +binned_gather = BinnedGatherOp.apply diff --git a/torch-ext/megablocks/ops/binned_scatter.py b/torch-ext/megablocks/ops/binned_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ce0d6f5a890a58b89e6a501216a9822341323f --- /dev/null +++ b/torch-ext/megablocks/ops/binned_scatter.py @@ -0,0 +1,59 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_scatter kernel. +class BinnedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + assert len(x.size()) == 3 + ctx.bin_size = x.size(1) + ctx.top_k = top_k + + # TODO(tgale): Don't save 'x' for backwards if we don't need to + # calculate the gradient w.r.t. 'weights'. 
+ ctx.save_for_backward(x, indices, weights, bins) + return kernels.binned_scatter(x, indices, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + x, indices, weights, bins = ctx.saved_tensors + out = kernels.binned_gather( + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[2]: + wgrad = kernels.binned_scatter_wgrad( + x, + grad, + indices, + bins, + ctx.top_k, + ) + return out, None, wgrad, None, None + + +binned_scatter = BinnedScatterOp.apply diff --git a/torch-ext/megablocks/ops/cumsum.py b/torch-ext/megablocks/ops/cumsum.py new file mode 100644 index 0000000000000000000000000000000000000000..a1974c56cb72d7caf2cafc7ceb701b5e56a6abc0 --- /dev/null +++ b/torch-ext/megablocks/ops/cumsum.py @@ -0,0 +1,52 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + # import megablocks_ops as ops # type: ignore + import megablocks._ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrappers for cumsum kernels. +# NOTE: Does not support gradients. +class ExclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int): + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.exclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.exclusive_cumsum(x, dim, out) + return out + + +exclusive_cumsum = ExclusiveCumsumOp.apply + + +class InclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.inclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.inclusive_cumsum(x, dim, out) + return out + + +inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/torch-ext/megablocks/ops/gather.py b/torch-ext/megablocks/ops/gather.py new file mode 100644 index 0000000000000000000000000000000000000000..41b09a1233e8a996ff1062579cd0810d095ad1e6 --- /dev/null +++ b/torch-ext/megablocks/ops/gather.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for gather kernel. 
+class GatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins) + ctx.top_k = top_k + return kernels.gather(x, indices, bin_ids, None, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins = ctx.saved_tensors + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) + return out, None, None, None, None, None + + +gather = GatherOp.apply diff --git a/torch-ext/megablocks/ops/histogram.py b/torch-ext/megablocks/ops/histogram.py new file mode 100644 index 0000000000000000000000000000000000000000..8e6a36ffe02e140397b63d2eb487fa7746b2324c --- /dev/null +++ b/torch-ext/megablocks/ops/histogram.py @@ -0,0 +1,27 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + import megablocks._ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for histogram kernel. +# NOTE: Does not support gradients. +class HistogramOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, max_val: float): + return ops.histogram(x, max_val) + + +histogram = HistogramOp.apply diff --git a/torch-ext/megablocks/ops/histogram_benchmark.py b/torch-ext/megablocks/ops/histogram_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..9de8e65281c829ba63b5c60853dfa5bb0a333988 --- /dev/null +++ b/torch-ext/megablocks/ops/histogram_benchmark.py @@ -0,0 +1,78 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_HISTOGRAM_TESTS = ( + (16384, torch.int32, 2), + (16384, torch.int32, 4), + (16384, torch.int32, 8), + (16384, torch.int32, 16), + (16384, torch.int32, 32), + (16384, torch.int32, 64), + (16384, torch.int32, 128), + (16384, torch.int32, 256), +) + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. 
+ fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class HistogramBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testHistogram(self, n, dtype, max_val): + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testTorchHistogram(self, n, dtype, max_val): + x = torch.randint(0, 128, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/torch-ext/megablocks/ops/matmul_benchmark.py b/torch-ext/megablocks/ops/matmul_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..bfa7b7c2f697f9454c5381e31ee78040be5b229e --- /dev/null +++ b/torch-ext/megablocks/ops/matmul_benchmark.py @@ -0,0 +1,403 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import stk +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + + +# Calling tensor.t() calls tensor.transpose(0, 1) which calls +# torch.as_strided(...). Circumvent this chain to avoid an overhead +# this adds. +def transpose_view(x): + return torch.as_strided( + x, + (x.shape[1], x.shape[0]), + (x.stride()[1], x.stride()[0]), + ) + + +_MATMUL_TESTS = ( + (64 * 1024, 512, 2048, 64), + (32 * 1024, 768, 3072, 64), + (8 * 1024, 1024, 4096, 64), + (4 * 2048, 4096, 4 * 4096, 4), +) + + +def log_benchmark(name, arguments, time, std, flops): + benchmark_util.log_benchmark(name, arguments, time, std) + print('flops = {:.2f}B'.format(flops / 1e9)) + print('throughput = {:.2f}T'.format(flops / 1e9 / time)) + print('=' * 60) + + +class MatmulBenchmark(parameterized.TestCase): + + def build_sparse_matrix(self, x, padded_bins, fhs, ne): + blocking = 128 + padded_tokens, _ = x.size() + assert padded_tokens % blocking == 0 + assert fhs % blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // blocking + blocks_per_row = fhs // blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. 
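The sparse topology being built in this benchmark is a block-CSR matrix in which every row of blocks has the same number of nonzeros. A quick worked example of the `offsets` computation (toy sizes, not taken from the benchmark table):

```python
import torch

blocking, padded_tokens, ffn_hidden_size = 128, 256, 512
block_rows = padded_tokens // blocking          # 2 rows of blocks
blocks_per_row = ffn_hidden_size // blocking    # 4 nonzero blocks in every row
offsets = torch.arange(
    0, block_rows * blocks_per_row + 1, blocks_per_row, dtype=torch.int32,
)
# tensor([0, 4, 8], dtype=torch.int32): row i's nonzero blocks occupy
# positions offsets[i]..offsets[i+1] in the flattened block data arrays.
```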
+ column_indices = ops.topology( + padded_bins, + blocking, + block_rows, + blocks_per_row, + ) + data = torch.empty( + column_indices.numel(), + blocking, + blocking, + dtype=torch.float16, + device=x.device, + ) + shape = (padded_tokens, fhs * ne) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix(shape, data, row_indices, column_indices, offsets) + + def build_input_matrix(self, sl, hs, ne): + x = torch.randn((sl, hs)).cuda().half() + + # Assign tokens to experts uniformly. + top_expert = torch.arange(0, sl).cuda().int() % ne + + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) + return out, padded_bins + + def build_weight_matrix(self, ne, hs, fhs): + return torch.randn((hs, ne * fhs)).cuda().half() + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(x, w, topo) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd::SDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(topo, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradX::DSD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + topo = topo.t() + + def benchmark(): + return stk.ops.dsd(topo, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DSD::NN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): 
+ x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(out, w, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::SDD::NT', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + x = x.t() + + def benchmark(): + return stk.ops.dsd(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + + w = w.transpose(1, 2).contiguous() + w = w.transpose(1, 2) + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd:DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + w = w.transpose(1, 2).contiguous() + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradX:DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + out = out.transpose(1, 2) + + def benchmark(): + return torch.bmm(out, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradW:DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) 
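The FLOP figures passed to `log_benchmark` in these tests follow the usual GEMM accounting: an `[m, k] x [k, n]` product performs `m * k * n` multiply-adds, i.e. `2 * m * k * n` FLOPs, which with `x.numel() == m * k` is the `x.numel() * fhs * 2` (or `* hs * 2`) seen above. For example:

```python
m, k, n = 8192, 1024, 4096
flops = 2 * m * k * n
print(f'{flops / 1e9:.2f} GFLOPs')   # 68.72 GFLOPs
```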
+ + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + w = torch.transpose(w, 1, 2) + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + x = torch.transpose(x, 1, 2) + + def benchmark(): + return torch.bmm(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/torch-ext/megablocks/ops/padded_gather.py b/torch-ext/megablocks/ops/padded_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..f272a7768dc6468e81bd2cd25294ca16a6826c08 --- /dev/null +++ b/torch-ext/megablocks/ops/padded_gather.py @@ -0,0 +1,55 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_gather kernel. +class PaddedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins, padded_bins) + ctx.top_k = top_k + return kernels.padded_gather( + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins, padded_bins = ctx.saved_tensors + out = kernels.padded_scatter( + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) + return out, None, None, None, None, None + + +padded_gather = PaddedGatherOp.apply diff --git a/torch-ext/megablocks/ops/padded_scatter.py b/torch-ext/megablocks/ops/padded_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff81dd9456a04258346266e9e85871bda56c65b --- /dev/null +++ b/torch-ext/megablocks/ops/padded_scatter.py @@ -0,0 +1,98 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_scatter kernel. 
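The padded gather/scatter pair here relies on per-expert bins rounded up to the 128-wide blocking used by the block-sparse matmuls. `ops.round_up` is defined later in this patch; its effect on a toy count vector (same arithmetic as that helper, shown inline for illustration):

```python
import torch

tokens_per_expert = torch.tensor([5, 130, 0, 200], dtype=torch.int32)
value = 128

# Round each count up to a multiple of `value`.
padded = torch.div(tokens_per_expert + (value - 1), value, rounding_mode='trunc') * value
# tensor([128, 256,   0, 256], dtype=torch.int32)

padded_bins = torch.cumsum(padded, dim=0)            # block-aligned bin boundaries
bins = torch.cumsum(tokens_per_expert, dim=0)        # unpadded bin boundaries
```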
+class PaddedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward( + indices, + bin_ids, + weights, + bins, + padded_bins, + *maybe_x, + ) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.padded_gather( + grad, + indices, + bin_ids, + weights, + bins, + padded_bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.padded_scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + padded_bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None, None + + +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): + return PaddedScatterOp.apply( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) diff --git a/torch-ext/megablocks/ops/padded_scatter_benchmark.py b/torch-ext/megablocks/ops/padded_scatter_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..81dde4e4b4c0f550c6027390e8de285fa4d842f7 --- /dev/null +++ b/torch-ext/megablocks/ops/padded_scatter_benchmark.py @@ -0,0 +1,66 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PADDED_SCATTER_BENCHMARK = ( + # dMoE-Medium, 8-way EMP. + (1024 * 16, 1024, 8, 4), + # dMoE-Medium, post-all-to-all. + (1024 * 16 * 4, 1024, 8, 1), +) + + +class PaddedScatterTest(parameterized.TestCase): + + @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) + def testPaddedScatter(self, sl, hs, ne, top_k): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + # Sample weights for the scatter reduce. + weights = torch.rand((sl * top_k,)).cuda().half() + + # Gather the data to prepare for backwards. 
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + def benchmark(): + return ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + time, std = benchmark_util.benchmark_function(benchmark) + benchmark_util.log_benchmark( + 'Padded Scatter', + { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + 'top_k': top_k, + }, + time, + std, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/torch-ext/megablocks/ops/permute_benchmark.py b/torch-ext/megablocks/ops/permute_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..837f07e2858d5c65c14e3f262617e7d2985ce901 --- /dev/null +++ b/torch-ext/megablocks/ops/permute_benchmark.py @@ -0,0 +1,149 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PERMUTE_TESTS = ( + (16384, 768, 2), + (16384, 768, 4), + (16384, 768, 8), + (16384, 768, 16), + (16384, 768, 32), + (16384, 768, 64), + (16384, 768, 128), + (16384 * 8, 768, 2), + (16384 * 8, 768, 4), + (16384 * 8, 768, 8), + (16384 * 8, 768, 16), + (16384 * 8, 768, 32), + (16384 * 8, 768, 64), + (16384 * 8, 768, 128), +) + + +class PermuteBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedGather(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.binned_gather(x, indices, bins, ec) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedScatter(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.binned_gather(x, indices, bins, ec) + + def benchmark(): + return ops.binned_scatter(x, indices, bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedGather(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. 
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedScatter(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testCopy(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + # ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + y = x.clone() + + def benchmark(): + return y.copy_(x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/torch-ext/megablocks/ops/repeat.py b/torch-ext/megablocks/ops/repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275 --- /dev/null +++ b/torch-ext/megablocks/ops/repeat.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def repeat(x: torch.Tensor, tiling: torch.Size): + if all((t == 1 for t in tiling)): + return x + return x.repeat(*tiling) diff --git a/torch-ext/megablocks/ops/replicate.py b/torch-ext/megablocks/ops/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..e2f85f06d14f799b06124ce2e7c297d575c4a261 --- /dev/null +++ b/torch-ext/megablocks/ops/replicate.py @@ -0,0 +1,37 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + # import megablocks_ops as ops # type: ignore + import megablocks._ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for replicate kernel. 
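Before the autograd wrapper below, the contract of the `replicate` kernel in one dense sketch (illustrative; the real op does this on the GPU without the Python loop): each column group of the output, delimited by `bins`, is filled with the corresponding input column.

```python
import torch

def replicate_reference(x: torch.Tensor, bins: torch.Tensor, num_outputs: int) -> torch.Tensor:
    # x: [rows, num_bins]; bins: inclusive end offsets with bins[-1] == num_outputs.
    out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
    start = 0
    for i in range(bins.numel()):
        end = int(bins[i])
        out[:, start:end] = x[:, i:i + 1]
        start = end
    return out
```

This is how `parallel_top_expert` in `moe.py` gets expanded from one id per (rank, expert) pair to one id per received token.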
+class ReplicateOp(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+        ctx.save_for_backward(bins)
+        out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+        ops.replicate_forward(x, bins, out)
+        return out
+
+    @staticmethod
+    def backward(ctx: Any, grad: torch.Tensor):
+        bins, = ctx.saved_tensors
+        out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+        ops.replicate_backward(grad, bins, out)
+        return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/torch-ext/megablocks/ops/round_up.py b/torch-ext/megablocks/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/torch-ext/megablocks/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+    assert isinstance(value, int)
+    assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+    # do this in a custom kernel. We only expect
+    # to use this on arrays of less than 1k elements.
+    return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
diff --git a/torch-ext/megablocks/ops/scatter.py b/torch-ext/megablocks/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5aaafc402a651a47595c25dbf71e382903e5022
--- /dev/null
+++ b/torch-ext/megablocks/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from stk.backend.autocast import custom_bwd, custom_fwd
+
+from megablocks.backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(
+        ctx: Any,
+        x: torch.Tensor,
+        indices: torch.Tensor,
+        bin_ids: torch.Tensor,
+        weights: torch.Tensor,
+        bins: torch.Tensor,
+        top_k: int,
+    ) -> torch.Tensor:
+        maybe_x = [x] if ctx.needs_input_grad[3] else []
+        ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+        ctx.top_k = top_k
+        ctx.x_shape = x.shape
+        return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx: Any, grad: torch.Tensor):
+        grad = grad.contiguous()
+        saved_tensors = ctx.saved_tensors
+
+        indices, bin_ids, weights, bins = saved_tensors[:4]
+        dgrad = None
+        if ctx.needs_input_grad[0]:
+            dgrad = kernels.gather(
+                grad,
+                indices,
+                bin_ids,
+                weights,
+                bins,
+                ctx.top_k,
+            )
+
+        wgrad = None
+        if ctx.needs_input_grad[3]:  # need wgrad
+            x = saved_tensors[-1]
+            wgrad = kernels.scatter_wgrad(
+                x,
+                grad,
+                indices,
+                bin_ids,
+                bins,
+                ctx.top_k,
+            )
+        return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+    x: torch.Tensor,
+    indices: torch.Tensor,
+    bin_ids: torch.Tensor,
+    weights: torch.Tensor,
+    bins: torch.Tensor,
+    top_k: int,
+) -> Optional[torch.Tensor]:
+    return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
diff --git a/torch-ext/megablocks/ops/sort.py b/torch-ext/megablocks/ops/sort.py
new file mode 100644
index 0000000000000000000000000000000000000000..14f8afa04d44148758e7750b75bd7869670d1e87
--- /dev/null
+++ b/torch-ext/megablocks/ops/sort.py
@@ -0,0 +1,39 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions.
Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + # import megablocks_ops as ops # type: ignore + import megablocks._ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + +_BITS_FOR_DTYPE = { + torch.int16: 16, + torch.int32: 32, + torch.int64: 64, +} + + +# Autograd wrapper for sort kernel. +# NOTE: Does not support gradients. +class SortOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + if end_bit is None: + end_bit = _BITS_FOR_DTYPE[x.dtype] + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + ops.sort(x, end_bit, x_out, iota_out) + return (x_out, iota_out) + + +sort = SortOp.apply diff --git a/torch-ext/megablocks/ops/sort_benchmark.py b/torch-ext/megablocks/ops/sort_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec --- /dev/null +++ b/torch-ext/megablocks/ops/sort_benchmark.py @@ -0,0 +1,85 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_SORT_TESTS = ( + (16384, torch.int32, None), + (16384, torch.int32, 2), + (16384, torch.int32, 128), +) + +_BASELINE_SORT_TESTS = ((16384,),) + + +def numpy_dtype(dtype): + types = { + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + } + return types[dtype] + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. + fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class SortBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_SORT_TESTS) + def testSort(self, n, dtype, max_val): + if max_val is None: + max_val = np.iinfo(numpy_dtype(dtype)).max + end_bit = int(np.ceil(np.log2(max_val))) + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_BASELINE_SORT_TESTS) + def testTorchSort(self, n): + x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) + arguments = { + 'n': n, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/torch-ext/megablocks/ops/sum.py b/torch-ext/megablocks/ops/sum.py new file mode 100644 index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc --- /dev/null +++ b/torch-ext/megablocks/ops/sum.py @@ -0,0 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import 
torch + + +def sum(x: torch.Tensor, dim: int = 0): + if x.shape[dim] == 1: + return x.squeeze(dim=dim) + return x.sum(dim=dim) diff --git a/torch-ext/megablocks/ops/topology.py b/torch-ext/megablocks/ops/topology.py new file mode 100644 index 0000000000000000000000000000000000000000..f914b4cc17f77e0b17caf46cec9d2a4743e223b8 --- /dev/null +++ b/torch-ext/megablocks/ops/topology.py @@ -0,0 +1,46 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + # import megablocks_ops as ops # type: ignore + import megablocks._ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for topology kernel. +# NOTE: Does not support gradients. +class TopologyOp(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, + ): + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + ops.indices( + padded_bins, + block_size, + output_block_rows, + output_block_columns, + out, + ) + return out + + +topology = TopologyOp.apply diff --git a/torch-ext/megablocks/py.typed b/torch-ext/megablocks/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..be38676a560827198e1d591a0f7a3c87354df0bd --- /dev/null +++ b/torch-ext/torch_binding.cpp @@ -0,0 +1,106 @@ +#include + +#include "registration.h" +#include "torch_binding.h" + +#include "new_cumsum.h" +#include "new_histogram.h" +#include "new_indices.h" +#include "new_replicate.h" +#include "new_sort.h" + +// void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) { +torch::Tensor exclusive_cumsum_wrapper(torch::Tensor x, int64_t dim, torch::Tensor out) { + megablocks::exclusive_cumsum(x, dim, out); + return out; +} + +// void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) { +torch::Tensor inclusive_cumsum_wrapper(torch::Tensor x, int64_t dim, torch::Tensor out) { + megablocks::inclusive_cumsum(x, dim, out); + return out; +} + +// torch::Tensor histogram(torch::Tensor x, int num_bins); +torch::Tensor histogram_wrapper(torch::Tensor x, int64_t num_bins) { + return megablocks::histogram(x, num_bins); +} + +// void indices(torch::Tensor padded_bins, +// int block_size, +// int output_block_rows, +// int output_block_columns, +// torch::Tensor out); +torch::Tensor indices_wrapper(torch::Tensor padded_bins, + int64_t block_size, + int64_t output_block_rows, + int64_t output_block_columns) { + torch::Tensor out = torch::empty({output_block_rows * output_block_columns}, torch::kInt16); + megablocks::indices(padded_bins, block_size, output_block_rows, output_block_columns, out); + return out; +} + + + +// // // Forward pass: replicate values from x according to bin sizes +// // void replicate_forward(torch::Tensor x, +// // torch::Tensor bins, +// // torch::Tensor out); +// tensor::Tensor replicate_forward_wrapper(torch::Tensor x, torch::Tensor bins, torch::Tensor out) { +// 
megablocks::replicate_forward(x, bins, out); +// return out; +// } + +// // Backward pass: reduce gradients back to bins using segmented reduction +// void replicate_backward(torch::Tensor grad, +// torch::Tensor bins, +// torch::Tensor out); +torch::Tensor replicate_backward_wrapper(torch::Tensor grad, torch::Tensor bins, torch::Tensor out) { + megablocks::replicate_backward(grad, bins, out); + return out; +} + +// // Public interface function for radix sorting with indices +// void sort(torch::Tensor x, +// int end_bit, +// torch::Tensor x_out, +// torch::Tensor iota_out); +torch::Tensor sort_wrapper(torch::Tensor x, int64_t end_bit, torch::Tensor x_out, torch::Tensor iota_out) { + megablocks::sort(x, end_bit, x_out, iota_out); + return x_out; +} + +// Reference implementation: +// +// m.def("exclusive_cumsum", &exclusive_cumsum, "batched exclusive cumsum."); +// m.def("histogram", &histogram, "even width histogram."); +// m.def("inclusive_cumsum", &inclusive_cumsum, "batched inclusive cumsum"); +// m.def("indices", &indices, "indices construction for sparse matrix."); +// m.def("replicate_forward", &replicate_forward, "(fwd) replicate a vector dynamically."); +// m.def("replicate_backward", &replicate_backward, "(bwd) replicate a vector dynamically."); +// m.def("sort", &sort, "key/value sort."); + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + ops.def("exclusive_cumsum(Tensor x, int dim, Tensor(a!) out) -> Tensor(a!)"); + ops.impl("exclusive_cumsum", torch::kCUDA, &exclusive_cumsum_wrapper); + + ops.def("inclusive_cumsum(Tensor x, int dim, Tensor(a!) out) -> Tensor(a!)"); + ops.impl("inclusive_cumsum", torch::kCUDA, &inclusive_cumsum_wrapper); + + ops.def("histogram(Tensor x, int num_bins) -> Tensor"); + ops.impl("histogram", torch::kCUDA, &histogram_wrapper); + + ops.def("indices(Tensor padded_bins, int block_size, int output_block_rows, int output_block_columns) -> Tensor"); + ops.impl("indices", torch::kCUDA, &indices_wrapper); + + // ops.def("replicate_forward(Tensor x, Tensor bins, Tensor(a!) out) -> Tensor(a!)"); + // ops.impl("replicate_forward", torch::kCUDA, &replicate_forward_wrapper); + + ops.def("replicate_backward(Tensor grad, Tensor bins, Tensor(a!) out) -> Tensor(a!)"); + ops.impl("replicate_backward", torch::kCUDA, &replicate_backward_wrapper); + + ops.def("sort(Tensor x, int end_bit, Tensor x_out, Tensor iota_out) -> Tensor(x_out)"); + ops.impl("sort", torch::kCUDA, &sort_wrapper); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) \ No newline at end of file diff --git a/torch-ext/torch_binding.h b/torch-ext/torch_binding.h new file mode 100644 index 0000000000000000000000000000000000000000..32185bf463aa61183da91ea0dc93d874b1b0990b --- /dev/null +++ b/torch-ext/torch_binding.h @@ -0,0 +1,6 @@ +#pragma once + +#include + +torch::Tensor exclusive_cumsum_wrapper(torch::Tensor x, int64_t dim, torch::Tensor out); +// torch::Tensor inclusive_cumsum_wrapper(torch::Tensor x, int64_t dim, torch::Tensor out); \ No newline at end of file
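
For reference, the Python side drives these bound kernels through megablocks.ops, as in the benchmarks above. A minimal sketch, assuming a CUDA device, the megablocks package from this tree on the import path, and illustrative sizes (top_k = 1 keeps shapes identical to the single-expert-per-token case):

    import torch
    from megablocks import ops

    sl, hs, ne, top_k = 16384, 768, 64, 1

    x = torch.randn((sl, hs)).cuda().half()
    top_expert = torch.randint(0, ne, (sl,)).cuda().int()

    # Sort the token->expert assignments and build (padded) bin boundaries.
    bin_ids, indices = ops.sort(top_expert)
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
    bins = ops.inclusive_cumsum(tokens_per_expert, 0)

    # Permute tokens into padded, expert-contiguous order; ops.padded_scatter
    # reverses this given the same index tensors plus per-token routing weights.
    x_permuted = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)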