diff --git a/.gitattributes b/.gitattributes index 39e7ae7fd0fdd2d8e5bc370225bb1f3eb8648ac8..0cd58331b2a989b68be4ec5676383437fca8687b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,4 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.xz filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text \ No newline at end of file +*tfevents* filter=lfs diff=lfs merge=lfs -text +*.so filter=lfs diff=lfs merge=lfs -text diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d8a2495c4baab89e5f90d84fa92f406a29ea10 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py @@ -0,0 +1,195 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from ._ops import ops + +from megablocks.layers.arguments import Arguments +from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE +from megablocks.layers.glu import SparseGLU +from megablocks.layers.mlp import MLP, SparseMLP +from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +# This section contains the direct kernel exports (not inlcuded in the original code) +def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute exclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.exclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute inclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.inclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: + """ + Compute histogram of input tensor values. + + Args: + x: Input tensor + num_bins: Number of histogram bins + + Returns: + Histogram tensor with counts for each bin + """ + return ops.histogram(x, num_bins) + + +def indices( + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, +) -> torch.Tensor: + """ + Construct indices from padded bins for sparse operations. + + Args: + padded_bins: Tensor containing bin boundaries + block_size: Size of each block + output_block_rows: Number of rows in output blocks + output_block_columns: Number of columns in output blocks + + Returns: + Tensor containing constructed indices + """ + return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) + + +def replicate_forward( + x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Forward pass of replicate operation - replicate values according to bin sizes. 
+ + Args: + x: Input tensor with values to replicate + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_forward(x, bins, out) + + +def replicate_backward( + grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Backward pass of replicate operation - reduce gradients back to bins. + + Args: + grad: Gradient tensor to reduce + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_backward(grad, bins, out) + + +def sort( + x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor +) -> torch.Tensor: + """ + Radix sort with index tracking. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + x_out: Output tensor for sorted values + iota_out: Output tensor for sorted indices + + Returns: + The sorted values tensor + """ + return ops.sort(x, end_bit, x_out, iota_out) + + +# Convenience functions for common use cases +def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: + """ + Compute cumulative sum with automatic output allocation. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum (default: last dimension) + exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum + + Returns: + New tensor containing the cumulative sum + """ + out = torch.empty_like(x) + if exclusive: + return exclusive_cumsum(x, dim, out) + else: + return inclusive_cumsum(x, dim, out) + + +def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sort tensor and return both sorted values and indices. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + + Returns: + Tuple of (sorted_values, sorted_indices) + """ + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + sort(x, end_bit, x_out, iota_out) + return x_out, iota_out + + +# Export public API +__all__ = [ + # Direct kernel exports + "exclusive_cumsum", + "inclusive_cumsum", + "histogram", + "indices", + "replicate_forward", + "replicate_backward", + "sort", + "cumsum", + "argsort", + # Original exports + "Arguments", + "ParallelDroplessMLP", + "dMoE", + "SparseGLU", + "MLP", + "SparseMLP", + "MoE", + "ParallelMLP", + "get_load_balancing_loss", +] diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..50fe814b70e1ecd024b8ad5bda99feb8bab489aa --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36a918d61308a0acbc880516a42fcc2c4dc393f25655326e52128465c9501709 +size 10456376 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd42432c910375fb1713b3bf29145e48b556498 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _megablocks_359242d +ops = torch.ops._megablocks_359242d + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_megablocks_359242d::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_version.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_version.py @@ -0,0 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +"""The MegaBlocks Version.""" + +__version__ = '0.11.0.dev0' diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..b584ceede926ca30abef2dec581cb3ff329e8e16 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py @@ -0,0 +1,543 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') + + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') + + +def assert_equal(a, b): + if a != b: + raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) + + +# a: (tokens, hidden_size), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy( + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Our index into array 'a'. + index_a = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'b'. + index_b = offset_in_bin + if bin_idx > 0: + index_b += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. 
Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: Because of the padding, the output size is dynamic. + # We load the final padded bin bound to get the output rows. + output_rows = padded_bins[-1].cpu().item() + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def gather(x, indices, bin_ids, weights, bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: There is no padding so the output rows equals the + # input rows multiplied by top_k. + output_rows = x.shape[0] * top_k + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + out, + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. 
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + + +def scatter(x, indices, bin_ids, weights, bins, top_k): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +# x: (tokens, top_k, hidden_size), real +# grad: (tokens, hidden_size), real. +# wgrad: (tokens, top_k), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy_wgrad( + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Our index into 'tokens * top_k'. + index_out = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'x'. + index_x = offset_in_bin + if bin_idx > 0: + index_x += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + _padded_copy_wgrad[(indices.shape[0],)]( + x, + grad, + out, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + TOP_K=top_k, + ) + return out + + +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. 
+@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy( + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_b = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_a = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + # + # NOTE: We need to zero the output in both directions. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + num_experts = bins.shape[0] + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + + _binned_copy[(num_experts, expert_capacity)]( + x, + out, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def binned_scatter(x, indices, weights, bins, top_k): + # Validate the input shapes. 
+ assert_is_tensor(x, 3) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( + out, + x, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=hidden_size, + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy_wgrad( + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_x = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_out = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def binned_scatter_wgrad(x, grad, indices, bins, top_k): + # Validate the input shapes. 
+ assert_is_tensor(x, 3) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) + _binned_copy_wgrad[(num_experts, expert_capacity)]( + x, + grad, + out, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS=hidden_size, + TOP_K=top_k, + ) + return out diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5217959caf74527e3bf7f80db6f93be21c016963 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py @@ -0,0 +1,23 @@ +from megablocks_moe.megablocks import ( + MoE, + dMoE, + get_load_balancing_loss, + ParallelMLP, + ParallelDroplessMLP, + SparseMLP, + MLP, + SparseGLU, + Arguments, +) + +__all__ = [ + "MoE", + "dMoE", + "get_load_balancing_loss", + "ParallelMLP", + "ParallelDroplessMLP", + "SparseMLP", + "MLP", + "SparseGLU", + "Arguments", +] diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py new file mode 100644 index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + + +def log_benchmark(name, arguments, time, std): + print('=' * 60) + print(f'{name} Benchmark') + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) + print('=' * 60) + + +def benchmark_function(fn, iterations=100, warmup=10): + # Warmup iterations. + for _ in range(warmup): + fn() + + times = [] + for i in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + fn() + end.record() + + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + return np.mean(times), np.std(times) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6d3f977f360fd0ad5800c3b5da9ce57be794b9b8 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import warnings + +_grouped_gemm_is_available: bool = False +try: + import grouped_gemm + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn('Grouped GEMM not available.') + + +def grouped_gemm_is_available(): + return _grouped_gemm_is_available + + +def assert_grouped_gemm_is_available(): + msg = ( + 'Grouped GEMM not available. 
Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg + + +backend = grouped_gemm.backend if grouped_gemm_is_available() else None +ops = grouped_gemm.ops if grouped_gemm_is_available() else None diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..849b023b1a0f765fb1e26addf4670dc6db785a52 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# from megablocks.layers.dmoe import dMoE +from megablocks.layers.moe import MoE + +__all__ = [ + 'MoE', + # 'dMoE', +] diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/activation_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..a31770ba179fed06abf2da10102ccaeed1d3ee4e --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/activation_fn.py @@ -0,0 +1,33 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Union + +import torch +from stk import Matrix + + +def act_fn( + x: Matrix, + function: Callable, + return_grad_fn: bool = False, + **kwargs, +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) + with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): + if return_grad_fn: + x.data.requires_grad = True + out = function(x.data, **kwargs) + y = Matrix( + x.size(), + out, + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + if return_grad_fn: + return y, out.backward + return y diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/all_to_all.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/all_to_all.py @@ -0,0 +1,54 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + + +class AllToAllOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) + + ctx.input_shape = x.shape + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + ctx.group = group + handle = dist.all_to_all_single( + out, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + return out, handle + + @staticmethod + def backward(ctx, grad, _): + if ctx.needs_input_grad[0]: + out = torch.empty( + ctx.input_shape, + device=grad.device, + dtype=grad.dtype, + ) + dist.all_to_all_single( + out, + grad, + output_split_sizes=ctx.input_split_sizes, + input_split_sizes=ctx.output_split_sizes, + group=ctx.group, + ) + return out, None, None, None, None + return None, None, None, None, None + + +def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): + return AllToAllOp.apply( + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) diff --git 
a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..3962c771c90012535aab058443f89c541a2e9236 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py @@ -0,0 +1,100 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from functools import partial +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +import megablocks.grouped_gemm_util as grouped_gemm + +# Type annotation for in-place Tensor initialization function. +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] + +_ALLOWED_BITWIDTHS = (-1, 4, 8) + +DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') + + +@dataclasses.dataclass +class Arguments: + # Model arguments. + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + num_layers: int = 1 + bias: bool = True + return_bias: bool = True + activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN + + # MoE arguments. + moe_num_experts: int = 1 + moe_top_k: int = 1 + moe_capacity_factor: int = 1 + moe_normalize_expert_weights: Optional[Union[int, float]] = None + moe_loss_weight: float = 0.1 + moe_jitter_eps: Optional[float] = None + moe_lbl_in_fp32: bool = False + + # Parallelism arguments. + moe_expert_model_parallelism: bool = False + expert_parallel_group: Optional[dist.ProcessGroup] = None + pipeline_model_parallel_size: int = 1 + num_layers_per_virtual_pipeline_stage: Optional[int] = None + + # Compute arguments. + memory_optimized_mlp: bool = False + mlp_type: str = 'mlp' + mlp_impl: str = 'sparse' + + # Initialization arguments. + fp16: bool = True + bf16: bool = False + device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) + init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) + output_layer_init_method: InitFn = init_method + + # Benchmarking arguments. + uniform_expert_assignment: bool = False + + # shared expert arguments + shared_expert: bool = False # enable using shared expert + fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) + fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers + remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored + shared_expert_hidden_size: Optional[ + int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size + shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) + + # Router Z-loss arguments + moe_zloss_weight: float = 0 # 1e-3 is a reasonable value + moe_zloss_in_fp32: bool = False + + def __post_init__(self): + # Sparse MLP is not supported with triton >=3.2.0 + # TODO: Remove this once sparse is supported with triton >=3.2.0 + if self.__getattribute__('mlp_impl') == 'sparse': + try: + import triton + if triton.__version__ >= '3.2.0': + raise ValueError( + 'Sparse MLP is not supported with triton >=3.2.0. 
Please use mlp_impl="grouped" instead.', + ) + except ImportError: + raise ImportError('Triton is required for sparse MLP implementation') + + if self.__getattribute__('mlp_impl') == 'grouped': + grouped_gemm.assert_grouped_gemm_is_available() + + if self.shared_expert_hidden_size is None: + self.shared_expert_hidden_size = self.ffn_hidden_size + + +def from_megatron(megatron_args: Any): + args = Arguments() + for field in dataclasses.fields(args): + if hasattr(megatron_args, field.name): + setattr(args, field.name, getattr(megatron_args, field.name)) + return args diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py new file mode 100644 index 0000000000000000000000000000000000000000..ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from megablocks.layers.arguments import Arguments + + +def dtype(args: Arguments): + if args.fp16: + return torch.float16 + elif args.bf16: + return torch.bfloat16 + return None + + +def cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == 'cuda': + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == 'cpu': + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d765bd04387a29bddd2789fae04821635b555e82 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -0,0 +1,42 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +MlpType = Union[mlp.SparseMLP, glu.SparseGLU] + +_REGISTRY = { + 'mlp': { + 'grouped': mlp.GroupedMLP, + 'sparse': mlp.SparseMLP, + }, + 'glu': { + 'grouped': glu.GroupedGLU, + 'sparse': glu.SparseGLU, + }, +} + + +def get(args: Arguments) -> MlpType: + """Returns an MLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + MLP instance. This only contains MLPs for use in dMoEs + (ie. only for the dropless versions of MoEs). + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated MLP constructed using the input args. 
+ """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + if args.mlp_impl not in _REGISTRY[args.mlp_type]: + raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) + + return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..205727ff4d63f9e8dc9648acaac99a97f3394d6f --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py @@ -0,0 +1,327 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import stk.ops +import torch +from stk import Matrix + +import megablocks.ops as ops +# from megablocks.ops import ops +from megablocks.layers import common, dmlp_registry, moe, mpu +from megablocks.layers.arguments import Arguments + + +def promote_scalar(x): + return x.view(1) if not len(x.size()) else x + + +class ParallelDroplessMLP(moe.ParallelMLP): + + def __init__(self, args: Arguments): + super(ParallelDroplessMLP, self).__init__(args) + self.hidden_size = args.hidden_size + self.ffn_hidden_size = mpu.features_per_rank(args) + self.blocking = 128 + self.mlp = dmlp_registry.get(args) + + # Calculate the number of bits needed to represent the column indices + # in the intermediate sparse matrix. + max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) + self.transpose_sort_end_bit = max( + int(np.ceil(np.log2(max_column_index))), + 1, + ) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + # + # NOTE: Our sort operation uses the same width indices as the input values. + # To avoid overflow when we have large activation matrices we cast to + # 32-bit before sorting. + _, gather_indices = ops.sort( + column_indices.int(), + self.transpose_sort_end_bit, + ) + + # There are a constant number of blocks in every row of the sparse matrix. + # A blocks offset is: + # + # row_index * blocks_per_row + column_index % blocks_per_row + # + # Once we have the block offsets ordered for transposition we can divide + # by blocks_per_row to get the transposed column indices. + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + if nnz_per_column.dim() == 0: + # This addresses an edge case when ffn_hidden_size is equal to self.blocking. + nnz_per_column = nnz_per_column.unsqueeze(0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + if self.ffn_hidden_size % self.blocking != 0: + raise ValueError( + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + + f'the block size {self.blocking}. Please update your configuration.', + ) + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. 
+ block_rows = padded_tokens // self.blocking + blocks_per_row = self.ffn_hidden_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology( + padded_bins, + self.blocking, + block_rows, + blocks_per_row, + ) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=common.dtype(self.args), + device='meta', + ) + shape = ( + padded_tokens, + self.ffn_hidden_size * mpu.experts_per_rank(self.args), + ) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, + row_indices, + column_indices, + offsets, + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_experts = top_experts.int() + bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Caculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = promote_scalar(bins) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def sparse_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather( + x, + indices, + bin_ids, + bins, + padded_bins, + self.top_k, + ) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.top_k, + ) + return x, tokens_per_expert + + # For use in the base-class parallel_forward_once. + def sparse_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. 
+ padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + return ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + top_k, + ) + + def grouped_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + out = self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + -1, # unused + self.args.moe_top_k, + ) + return out, tokens_per_expert + + def grouped_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Perform the expert computation. + x = self.mlp(x, tokens_per_expert) + + # Un-route the data for the MoE output. + return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + def forward_once(self, x, expert_weights, top_experts): + if self.args.mlp_impl == 'sparse': + return self.sparse_forward_once(x, expert_weights, top_experts) + else: + return self.grouped_forward_once(x, expert_weights, top_experts) + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ): + if self.args.mlp_impl == 'sparse': + return self.sparse_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + else: + return self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + + +class dMoE(moe.MoE): + + def _init_experts_mlp(self, args: Arguments): + return ParallelDroplessMLP(args) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/gelu.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..40b601d4a8eb59e80e1090f31f4172b8f7fb7549 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/gelu.py @@ -0,0 +1,43 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk +import torch +import torch.nn.functional as F + + +@torch.jit.script +def _gelu_backward_inplace(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) + return g.mul_(ff) + + +def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): + # NOTE: The two sparse matrices must have the same topology. 
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): + return stk.Matrix( + x.size(), + _gelu_backward_inplace(grad.data, x.data), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + return _gelu_backward_inplace(grad, x) + + +def gelu(x: stk.Matrix): + assert isinstance(x, stk.Matrix) + return stk.Matrix( + x.size(), + F.gelu(x.data, approximate='tanh'), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe0c915c307e7f7cade3ea3ff679399635fcd81 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py @@ -0,0 +1,223 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk.ops +import torch + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import Arguments +from megablocks.layers.mlp import ( + SharedMLP, + SparseMLP, + create_dmoe_expert_weights, + resolve_dtensor, +) + + +class SparseGLU(SparseMLP): + + def __init__(self, args: Arguments): + super().__init__(args) + self.v1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + with torch.no_grad(): + self.v1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + + mpu.set_expert_model_parallel_attributes( + self.v1, + self._should_set_parallelism_attribute, + ) + + def forward(self, x, topo): + if self.args.memory_optimized_mlp: + raise NotImplementedError( + 'Memory optimized implementation not yet supported with GLU with sparse kernels.', + ) + + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Compute the GLU. + x1 = stk.ops.sdd(x, w1.t(), topo) + x2 = stk.ops.sdd(x, v1.t(), topo) + + activation_fn_out = act_fn(x1, self.args.activation_fn) + x1 = stk.ops.mul(activation_fn_out, x2) + + return stk.ops.dsd(x1, w2) + + +class MemoryOptimizedGroupedGLU(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + v1 = v1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) + + # GeLU. + activation_fn_out = activation_fn(sdd_out) * v1_out + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. 
We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, v1, w2 = saved_tensors[:3] + batch_sizes = saved_tensors[3] + x = saved_tensors[4] + sdd_out, v1_out = saved_tensors[5:7] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + v1_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * v1_out + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + dv1_out = v1_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dv1. + dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + dx = ddsd_out + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) + dx += gg.backend.gmm(dv1_out, v1, batch_sizes) + return dx, dw1, dv1, dw2, None, None + + +memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply + + +class GroupedGLU(SparseGLU): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, v1, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.v1), + self.scale_grad(self.w2), + ) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Re-shape the weights for the grouped GEMMs. + ne = mpu.experts_per_rank(self.args) + w1 = w1.view(ne, -1, self.args.hidden_size) + v1 = v1.view(ne, -1, self.args.hidden_size) + w2 = w2.view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_glu( + x, + w1, + v1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) + x1 = self.args.activation_fn(x1) * x2 + return gg.ops.gmm(x1, w2, batch_sizes) + + +class SharedGLU(SharedMLP): + """GPU for shared expert. 
+ + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class + """ + + def __init__(self, args: Arguments): + super().__init__(args) + self.gate_proj = args.fc_cls( + args.hidden_size, + self.args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4acbd94f212ce906ace9de78dd3ffc6afa03f97e --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py @@ -0,0 +1,102 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import gc + +import torch +import torch.distributed as dist + +from megablocks.layers import arguments, dmoe + +_TESTS = ((8, 2048, 4096, 4096, 32, 4),) + + +def get_tensors(): + ptrs = set() + out = [] + for obj in gc.get_objects(): + if torch.is_tensor(obj): + if not obj.is_contiguous() or obj.data_ptr() in ptrs: + continue + out.append(obj) + ptrs.add(obj.data_ptr()) + return out + + +def test_memory( + group, + batch_size, + sequence_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, +): + args = arguments.Arguments( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=num_experts, + moe_top_k=top_k, + moe_expert_model_parallelism=True, + expert_parallel_group=group, + fp16=False, + bf16=True, + device=torch.cuda.current_device(), + ) + layer = dmoe.dMoE(args).cuda() + + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16).requires_grad_(True) + torch.cuda.empty_cache() + + # Run forward + backward. + # with torch.autograd.detect_anomaly(): + out, _ = layer(x) + out.mean().backward() + + # Report peak memory. + mem = torch.cuda.max_memory_allocated() + print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) + print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) + + # Calculate weight and gradient memory usage. + weight_memory = 2 * ( + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() + ) + + def grad_numel(x): + if x.grad is not None: + return x.grad.numel() + return 0 + + grad_memory = 2 * ( + grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + ) + weight_memory += grad_memory + + print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) + print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) + + # Manually calculate GPU memory usage from the garbage + # collector. 
+ gc.collect() + total = 0 + tensors = get_tensors() + tensors = sorted(tensors, key=lambda x: -x.numel()) + for i, t in enumerate(tensors): + total += t.numel() + print(f'{i}: {t.shape}, {t.numel() * 2}') + del tensors + + print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _TESTS: + test_memory(group, *args) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6f4d82441e3a9c185db5bfdf686d53790dde26 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py @@ -0,0 +1,574 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import stk +import stk.backend.triton_kernels +import stk.ops +import torch +from packaging import version + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, gelu, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + + +class ScaleGradient(torch.autograd.Function): + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx: Any, x: torch.Tensor, scale: float): + ctx.scale = scale + return x + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: torch.Tensor, grad: torch.Tensor): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +def resolve_dtensor(weight: torch.Tensor): + if version.parse(torch.__version__) >= version.parse('2.0.0'): + from torch.distributed._tensor import DTensor + if isinstance(weight, DTensor): + return weight.to_local() + return weight + + +def create_moe_expert_weights( + args: Arguments, + num_experts: int, + ffn_hidden_size: int, + hidden_size: int, + init_method: InitFn, +): + # Create the entire weight matrix such that the sampled weights will + # not vary between data parallelism and expert model parallelism for + # the same random seed. + master_weights = torch.empty( + num_experts, + ffn_hidden_size, + hidden_size, + device=args.device, + dtype=common.dtype(args), + ) + init_method(master_weights) + + if not args.moe_expert_model_parallelism: + return master_weights + + # Calculate the amount of sharding in each dimension. + expert_sharding_degree = mpu.expert_sharding_degree(args) + hidden_sharding_degree = mpu.hidden_sharding_degree(args) + + # Calculate the experts per rank. + # + # NOTE: We assign ranks to be expert parallel before going + # tensor parallel. + rank = mpu.get_expert_parallel_rank(args) + expert_rank = rank % expert_sharding_degree + num_experts_per_rank = num_experts // expert_sharding_degree + start_expert = expert_rank * num_experts_per_rank + end_expert = (expert_rank + 1) * num_experts_per_rank + + # Calculate the rows per rank. + row_rank = rank // expert_sharding_degree + num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree + start_row = row_rank * num_rows_per_rank + end_row = (row_rank + 1) * num_rows_per_rank + + # Slice the weight matrix to get the chunk for this rank. 
+ with torch.no_grad(): + weights = master_weights[start_expert:end_expert, start_row:end_row] + return weights + + +class MLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) + experts_per_rank = mpu.experts_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + args.hidden_size, + mpu.features_per_rank(args), + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + mpu.features_per_rank(args), + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + mpu.set_expert_model_parallel_attributes( + self.w1, + args.moe_expert_model_parallelism, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + args.moe_expert_model_parallelism, + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. + with torch.no_grad(): + w1 = create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) + self.w1.copy_(w1.transpose(1, 2).contiguous()) + self.w2.copy_( + create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + x = torch.bmm(x, w1) + x = self.args.activation_fn(x) + return torch.bmm(x, w2) + + +def create_dmoe_expert_weights( + args: Arguments, + num_experts: int, + rows: int, + columns: int, + init_method: InitFn, +): + weights = create_moe_expert_weights( + args, + num_experts, + rows, + columns, + init_method, + ) + return weights.view([-1, columns]) + + +class MemoryOptimizedMLP(torch.autograd.Function): + """Sparse MLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, topo, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + topo_tensors = ( + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) + + # Layer 0: x @ w1.t(). + sdd_out = stk.ops.sdd(x, w1.t(), topo) + + # GeLU. + activation_fn_out = act_fn(sdd_out, activation_fn) + + # Layer 1: x @ w2. + dsd_out = stk.ops.dsd(activation_fn_out, w2) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. 
We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.shape = topo.shape + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.data.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + topo_tensors = saved_tensors[2:8] + x = saved_tensors[8] + sdd_out_data = saved_tensors[9] + + # rematerialize activation function output + activation_fn = ctx.activation_fn + sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) + activation_fn_out, activation_grad_fn = act_fn( + sdd_out, + activation_fn, + return_grad_fn=True, + ) + + # Compute dw2 with recomputed activation_fn output. + dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + stk.backend.triton_kernels.sdd( + ddsd_out, + w2.t(), + dactivation_fn_out.shape, + dactivation_fn_out.data, + dactivation_fn_out.offsets, + dactivation_fn_out.row_indices, + dactivation_fn_out.column_indices, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out.data) + dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) + + # Compute dw1. + dw1 = stk.ops.dsd(dsdd_out.t(), x) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + stk.backend.triton_kernels.dsd( + dsdd_out.shape, + dsdd_out.data, + dsdd_out.offsets, + dsdd_out.row_indices, + dsdd_out.column_indices, + dsdd_out.offsets_t, + dsdd_out.column_indices_t, + dsdd_out.block_offsets_t, + False, + w1, + ddsd_out, + ) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_mlp = MemoryOptimizedMLP.apply + + +class SparseMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. 
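+        # create_dmoe_expert_weights returns the sampled weights flattened to
+        # [experts_per_rank * features_per_rank, hidden_size], matching the 2D
+        # w1/w2 parameters allocated above.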
+ with torch.no_grad(): + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism + mpu.set_expert_model_parallel_attributes( + self.w1, + self._should_set_parallelism_attribute, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + self._should_set_parallelism_attribute, + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x, topo): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + if self.args.memory_optimized_mlp: + return memory_optimized_mlp( + x, + w1, + w2, + topo, + self.args.activation_fn, + ) + + # Compute the MLP. + x = stk.ops.sdd(x, w1.t(), topo) + activation_fn_out = act_fn(x, self.args.activation_fn) + return stk.ops.dsd(activation_fn_out, w2) + + +class MemoryOptimizedGroupedMLP(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + + # activation_fn + activation_fn_out = activation_fn(sdd_out) + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: Any, ddsd_out: torch.Tensor): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + batch_sizes = saved_tensors[2] + x = saved_tensors[3] + sdd_out = saved_tensors[4] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. 
+ # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply + + +class GroupedMLP(SparseMLP): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) + + # Re-shape the weights for the grouped GEMMs. + ne = mpu.experts_per_rank(self.args) + w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) + w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_mlp( + x, + w1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x = self.args.activation_fn(x) + return gg.ops.gmm(x, w2, batch_sizes) + + +class SharedMLP(torch.nn.Module): + """MLP for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class + """ + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self.fc_kwargs: dict[str, Any] = { + 'bias': args.bias, + 'device': args.device, + } + self.fc_kwargs.update(args.fc_kwargs) + + self.up_proj = args.fc_cls( + args.hidden_size, + args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + self.act = args.activation_fn + self.down_proj = args.fc_cls( + args.shared_expert_hidden_size, + args.hidden_size, + **self.fc_kwargs, + ) + self.down_proj._is_residual = True # a flag for llm-foundry init + + def add_experts_sharedexpert( + self, + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + ) -> torch.Tensor: + # Helper function to add expert output to shared expert output + # with optional weighted sum. 
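+        # With the weighted sum enabled the two streams are combined as
+        # (shared_expert_out + moe_top_k * expert_out) / (moe_top_k + 1),
+        # e.g. shared / 3 + (2 / 3) * expert_out for moe_top_k = 2.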
+ if self.args.shared_expert_weighted_sum: + # enable using weighted sum for shared expert output + # wieghted by number of experts used + t_experts = self.args.moe_top_k + 1 + sh_mlp_out = shared_expert_out / t_experts + return sh_mlp_out.add( + expert_out, + alpha=(self.args.moe_top_k / t_experts), + ) + + return shared_expert_out + expert_out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py @@ -0,0 +1,475 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist + +import megablocks.ops as ops +from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +from megablocks.layers.all_to_all import all_to_all +from megablocks.layers.arguments import Arguments + +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss(args: Arguments): + if args.moe_loss_weight == 0: + return 0.0 + + # tokens_per_expert[i].shape = (num_experts) + # expert_scores[i].shape = (tokens, num_experts) + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} token_per_experts ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} expert_scores ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) + + tokens = expert_scores[0].shape[0] + assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. 
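+    # After concatenation expert_scores has shape
+    # [tokens, num_layers_per_pipeline_stage * moe_num_experts] and the
+    # concatenated tokens_per_expert has the same number of elements, as
+    # asserted below.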
+ expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = (args.moe_num_experts * args.moe_loss_weight) + scale_denominator = (args.num_layers * tokens * args.moe_top_k) + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +# NOTE: This class defines MoE expert computation, including expert model parallel +# communication. When using FSDP on top of MegaBlocks this is the module that should +# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model +# parallel all2all. +class ParallelMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super(ParallelMLP, self).__init__() + self.args = args + + # Calculate the number of experts in total and the number of experts + # owned by this rank. + # world_size = mpu.get_expert_parallel_world_size(args) + self.num_experts = args.moe_num_experts + self.top_k = self.args.moe_top_k + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + # Expert MLP. + self.mlp = mlp.MLP(args) + + self.bias: Optional[torch.Tensor] + if self.args.bias: + # Note that the output bias is not parallelized with expert + # model parallelism. + self.bias = torch.nn.Parameter( + torch.empty( + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + # Select the forward function for the operating mode. + self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) + + def expert_capacity(self, tokens: int) -> int: + world_size = mpu.get_expert_parallel_world_size(self.args) + tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) + return int(self.args.moe_capacity_factor * tokens_per_expert) + + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): + """Calculate the load balancing loss contribution.""" + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == self.num_experts + assert len(tokens_per_expert.size()) == 1 + num_experts, = tokens_per_expert.size() + assert num_experts == self.num_experts + scale = self.num_experts / (tokens * self.top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # + # TODO(tgale): Is it worth doing this conversion to 32-bit + # prior? Could we place the `torch.max` operation to return + # 32-bit expert indices? 
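+        # bin_ids holds the expert ids in sorted order; indices is the
+        # matching permutation used later to gather/scatter the tokens.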
+ top_expert = top_expert.int() + output = ops.sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + # + # TODO(tgale): Does the sorted data produce a more favorable + # data distribution for histogram? Or is the op parallelism + # worth more? + tokens_per_expert = ops.histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None + bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + + return indices, bin_ids, bins, tokens_per_expert + + def permute_and_compute( + self, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, + ): + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output + + # Perform the expert computation. Note that we don't + # use biases for these linear operations. + x = self.mlp(x) + + # Un-route the data for the MoE output. + return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + sl, bs, _ = x.size() + expert_capacity = self.expert_capacity(sl * bs) + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = self.permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + self.top_k, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignement. After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + # + # Compute the mapping of local tokens to experts. 
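+        # expert_weights and top_experts arrive as [sl * bs, top_k]; flatten
+        # them so every (token, k) assignment is handled independently.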
+ expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so every device gets the counts. + repeated_tokens_per_expert = ops.repeat( + tokens_per_expert, + (mpu.hidden_sharding_degree(self.args),), + ) + + # Pass token count information to the device on which the + # target expert resides. + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.args.expert_parallel_group, + async_op=True, + ) + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + # + # This view updates the shape of the tensor from [sl, bs, hs] to + # [sl * bs, hs] prior to the permutation. + x = x.view(-1, x.shape[-1]) + output = ops.gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + experts_per_rank = mpu.experts_per_rank(self.args) + + # Reshape to [world_size, num_experts_per_rank]. + world_size = mpu.get_expert_parallel_world_size(self.args) + repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) + + # TODO(tgale): It might be faster to do this on the GPU and + # then communicate the results back to the host. + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. + send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so devices that share experts + # get all of the tokens assigned to them. + # + # TODO(tgale): Fuse this into the prior, local permutation. + x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, + recv_counts, + send_counts, + self.args.expert_parallel_group, + async_op=True, + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + replicate_bins = ops.inclusive_cumsum( + parallel_tokens_per_expert.flatten(), + 0, + ) + replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) + + # Construct the expert indices for the permuted tokens. 
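+            # Tokens were received as one contiguous chunk per
+            # (source device, local expert) pair, so the local expert ids form
+            # an arange modulo experts_per_rank, replicated according to the
+            # per-chunk token counts in replicate_bins.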
+ parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * mpu.hidden_sharding_degree(self.args), + dtype=torch.int32, + device=indices.device, + ), + mpu.experts_per_rank(self.args), + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, + tokens_received, + ).flatten() + + # TODO(tgale): The sort_end_bit here can be reduced. + parallel_bin_ids, parallel_indices = ops.sort( + parallel_top_expert, + self.sort_end_bit, + ) + + # Calculate the bins boundaries from the token counts. + parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, + dtype=torch.int, + ) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + tokens, _ = x.size() + expert_capacity = self.expert_capacity(tokens) + if expert_capacity == 0: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + if self.args.mlp_impl == 'grouped': + # GroupedMLP requires counts on CPU. We can use the tensor already + # moved to CPU for the prior all_to_all, which avoids an extra + # device synchronization. + parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( + dim=0, + dtype=torch.int, + ) + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all( + parallel_x, + send_counts, + recv_counts, + self.args.expert_parallel_group, + ) + + # Reduce along the hidden sharding to get the final outputs. + # + # TODO(tgale): Fuse this into the following local permutation. + shape = ( + mpu.hidden_sharding_degree(self.args), + -1, + self.args.hidden_size, + ) + x = ops.sum(x.view(shape), dim=0) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + return x, tokens_per_expert.flatten() + + def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + in_shape = x.size() + + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) + if self.training and self.args.moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, scores)) + x = x.view(in_shape) + if self.bias is not None: + if self.args.return_bias: + return x, self.bias + return x + self.bias + return x + + +class MoE(torch.nn.Module): + + def __init__(self, args: Arguments): + super(MoE, self).__init__() + + # Token router. + self.router = router.LearnedRouter(args) + + # Expert computation helper. + self.experts = self._init_experts_mlp(args) + + self.shared_expert = None + if args.shared_expert: + # SharedExpert computation helper. + self.shared_expert = sharedexpert_registry.get(args) + + def _init_experts_mlp(self, args: Arguments): + return ParallelMLP(args) + + def forward(self, x: torch.Tensor): + # NOTE: If we're going to cast the activations to lower precision + # do it before we permute the tokens to save bandwidth. 
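+        # Casting before routing means the local permutations (and, with
+        # expert parallelism, the all-to-all) move lower-precision activations
+        # instead of fp32 ones.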
+ x = common.cast_if_autocast_enabled(x) + + # Compute the expert scores and assignments. + scores, expert_weights, top_experts = self.router(x) + + # Compute the experts. + out = self.experts(x, scores, expert_weights, top_experts) + if self.shared_expert is not None: + shared_expert_out = self.shared_expert(x) + out = self.shared_expert.add_experts_sharedexpert( + shared_expert_out, + out, + ) + return out diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py new file mode 100644 index 0000000000000000000000000000000000000000..b23213902f4567d7fdb0158cbcf5406c2b2aa601 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py @@ -0,0 +1,93 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.distributed as dist + +from megablocks.layers.arguments import Arguments + + +class MoeParam(torch.Tensor): + + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool + + +def is_moe_param(tensor: torch.Tensor) -> bool: + return hasattr(tensor, 'expert_model_parallel') + + +def get_expert_parallel_world_size(args: Arguments) -> int: + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) + + +def get_expert_parallel_rank(args: Arguments) -> int: + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) + + +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, 'expert_model_parallel') + setattr(tensor, 'expert_model_parallel', is_parallel) + + +def param_is_expert_model_parallel(param: MoeParam) -> bool: + return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) + + +def copy_expert_model_parallel_attributes( + destination_tensor: torch.Tensor, + source_tensor: torch.Tensor, +): + if hasattr(source_tensor, 'expert_model_parallel'): + setattr( + destination_tensor, + 'expert_model_parallel', + getattr(source_tensor, 'expert_model_parallel'), + ) + + +def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + for i in range(world_size): + dist.barrier(group) + if i == rank: + print(f'rank = {rank}', *x) + + +# Helpers for expert/tensor sharding. +def expert_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = min(world_size, args.moe_num_experts) + + if (args.moe_num_experts % esd) != 0: + raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) + return esd + + +def hidden_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = expert_sharding_degree(args) + hsd = world_size // esd + + if (args.ffn_hidden_size % hsd) != 0: + raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. 
'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", + ) + return hsd + + +def experts_per_rank(args: Arguments) -> int: + return args.moe_num_experts // expert_sharding_degree(args) + + +def features_per_rank(args: Arguments) -> int: + return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9dcd9e433322482a80d5b95afccee5c12368f8 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py @@ -0,0 +1,114 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch + +from megablocks.layers import common +from megablocks.layers.arguments import Arguments + +_ROUTER_LOGITS = [] + + +def _save_router_logits(logits: torch.Tensor, args: Arguments): + if args.moe_zloss_weight == 0: + return + global _ROUTER_LOGITS + _ROUTER_LOGITS.append(logits) + + +def clear_router_zloss(): + global _ROUTER_LOGITS + _ROUTER_LOGITS.clear() + + +def batched_router_zloss(args: Arguments): + global _ROUTER_LOGITS + + if args.moe_zloss_weight == 0: + import warnings + warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') + return 0 + + logits_per_router = _ROUTER_LOGITS + + if args.moe_zloss_in_fp32: + logits_per_router = [logits.float() for logits in logits_per_router] + + unscaled_zloss_per_router = torch.stack([ + torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router + ]) + + return args.moe_zloss_weight * unscaled_zloss_per_router + + +# NOTE: To enable end-to-end benchmarking without convergence we +# support a flag to force the router to assign tokens uniformly +# across the experts. We do this with a custom autograd operation +# so that PyTorch still executes the full set of router operation. +class _UniformExpertAssignment(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, num_experts: int): + out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) + out = torch.remainder(out, num_experts) + return out.view(x.shape) + + +_uniform_expert_assignment = _UniformExpertAssignment.apply + + +class LearnedRouter(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + + # Learned router parameters. + # + # NOTE: This weight matrix is not parallelized with expert model + # parallelism. Each device needs the entire router weight matrix + # so that it can route its batch of data correctly. 
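+        # The projection maps hidden_size -> moe_num_experts, producing one
+        # routing logit per expert for every token.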
+ self.layer = torch.nn.Linear( + args.hidden_size, + args.moe_num_experts, + bias=False, + dtype=common.dtype(args), + device=args.device, + ) + args.init_method(self.layer.weight) + + def jitter(self, x: torch.Tensor): + low: float = 1.0 - self.args.moe_jitter_eps + high: float = 1.0 + self.args.moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return low + noise * (high - low) + + def _top_k(self, scores: torch.Tensor): + if self.args.moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, self.args.moe_top_k, dim=-1) + + def forward(self, x: torch.Tensor): + if self.training and self.args.moe_jitter_eps is not None: + x = x * self.jitter(x) + + logits = self.layer(x.view(-1, x.shape[-1])) + _save_router_logits(logits, self.args) + scores = logits.softmax(dim=-1) + expert_weights, expert_indices = self._top_k(scores) + if self.args.moe_normalize_expert_weights: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + + expert_indices = ( + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) if self.args.uniform_expert_assignment else expert_indices + ) + return scores, expert_weights, expert_indices diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..0f62db39a2dda5642f6baa9d9b949f7c31cf6d35 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -0,0 +1,30 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +_REGISTRY = { + 'mlp': mlp.SharedMLP, + 'glu': glu.SharedGLU, +} + + +def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: + """Returns an SharedMLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + SharedMLP instance. + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated SharedMLP constructed using the input args. 
+ """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + return _REGISTRY[args.mlp_type](args) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from megablocks.ops.binned_gather import binned_gather +from megablocks.ops.binned_scatter import binned_scatter +from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum +from megablocks.ops.gather import gather +from megablocks.ops.histogram import histogram +from megablocks.ops.padded_gather import padded_gather +from megablocks.ops.padded_scatter import padded_scatter +from megablocks.ops.repeat import repeat +from megablocks.ops.replicate import replicate +from megablocks.ops.round_up import round_up +from megablocks.ops.scatter import scatter +from megablocks.ops.sort import sort +from megablocks.ops.sum import sum +from megablocks.ops.topology import topology + +__all__ = [ + 'binned_gather', + 'binned_scatter', + 'exclusive_cumsum', + 'inclusive_cumsum', + 'gather', + 'histogram', + 'padded_gather', + 'padded_scatter', + 'repeat', + 'replicate', + 'round_up', + 'scatter', + 'sort', + 'sum', + 'topology', +] diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..47b95301dbb35b0c3154d1ff2878d20792ce5cb2 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -0,0 +1,60 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + +from megablocks import benchmark_util +from megablocks.layers.all_to_all import all_to_all + +_ALL_TO_ALL_BENCHMARK = ( + (8, 1024), + (16, 1024), + (32, 1024), + (64, 1024), + (128, 1024), + (256, 1024), + (512, 1024), + (1024, 1024), + (2 * 1024, 1024), + (4 * 1024, 1024), + (8 * 1024, 1024), + (16 * 1024, 1024), + (32 * 1024, 1024), + (64 * 1024, 1024), + (128 * 1024, 1024), + (256 * 1024, 1024), + (512 * 1024, 1024), + (1024 * 1024, 1024), +) + + +def benchmark_all_to_all(group, sl, hs): + world_size = dist.get_world_size(group) + assert (sl % world_size) == 0 + send_recv_sizes = [sl // world_size] * world_size + + x = torch.randn((sl, hs)).cuda().half() + + details = { + 'world_size': world_size, + 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. 
+ } + + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) + + time, std = benchmark_util.benchmark_function(benchmark) + + if dist.get_rank(group) == 0: + benchmark_util.log_benchmark('All-To-All', details, time, std) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _ALL_TO_ALL_BENCHMARK: + benchmark_all_to_all(group, *args) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..89cce1b627137d7f987baca62fd6d82c5c04659a --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py @@ -0,0 +1,37 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_gather kernel. +class BinnedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): + ctx.save_for_backward(indices, bins) + ctx.top_k = top_k + return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + indices, bins = ctx.saved_tensors + out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) + return out, None, None, None, None + + +binned_gather = BinnedGatherOp.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ce0d6f5a890a58b89e6a501216a9822341323f --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py @@ -0,0 +1,59 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_scatter kernel. +class BinnedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + assert len(x.size()) == 3 + ctx.bin_size = x.size(1) + ctx.top_k = top_k + + # TODO(tgale): Don't save 'x' for backwards if we don't need to + # calculate the gradient w.r.t. 'weights'. 
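+        # 'x' is only needed in backward to compute wgrad when the scatter
+        # weights require grad (ctx.needs_input_grad[2]), which is what the
+        # TODO above refers to.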
+ ctx.save_for_backward(x, indices, weights, bins) + return kernels.binned_scatter(x, indices, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + x, indices, weights, bins = ctx.saved_tensors + out = kernels.binned_gather( + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[2]: + wgrad = kernels.binned_scatter_wgrad( + x, + grad, + indices, + bins, + ctx.top_k, + ) + return out, None, wgrad, None, None + + +binned_scatter = BinnedScatterOp.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7c5137c900d66a959a76fa681436362a4df906 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py @@ -0,0 +1,52 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + # import megablocks_ops as ops # type: ignore + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrappers for cumsum kernels. +# NOTE: Does not support gradients. +class ExclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int): + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.exclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.exclusive_cumsum(x, dim, out) + return out + + +exclusive_cumsum = ExclusiveCumsumOp.apply + + +class InclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.inclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.inclusive_cumsum(x, dim, out) + return out + + +inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py new file mode 100644 index 0000000000000000000000000000000000000000..41b09a1233e8a996ff1062579cd0810d095ad1e6 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for gather kernel. 
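+# The backward pass un-routes the incoming gradient with kernels.scatter; the
+# integer index tensors receive no gradient.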
+class GatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins) + ctx.top_k = top_k + return kernels.gather(x, indices, bin_ids, None, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins = ctx.saved_tensors + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) + return out, None, None, None, None, None + + +gather = GatherOp.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py new file mode 100644 index 0000000000000000000000000000000000000000..77e06a29163aa335dc61e8c50932d5f83322ae1d --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py @@ -0,0 +1,27 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for histogram kernel. +# NOTE: Does not support gradients. +class HistogramOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, max_val: float): + return ops.histogram(x, max_val) + + +histogram = HistogramOp.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..9de8e65281c829ba63b5c60853dfa5bb0a333988 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -0,0 +1,78 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_HISTOGRAM_TESTS = ( + (16384, torch.int32, 2), + (16384, torch.int32, 4), + (16384, torch.int32, 8), + (16384, torch.int32, 16), + (16384, torch.int32, 32), + (16384, torch.int32, 64), + (16384, torch.int32, 128), + (16384, torch.int32, 256), +) + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. 
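+    # The timed iterations that follow bracket fn() with CUDA events and
+    # synchronize before reading elapsed_time() in milliseconds.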
+ fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class HistogramBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testHistogram(self, n, dtype, max_val): + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testTorchHistogram(self, n, dtype, max_val): + x = torch.randint(0, 128, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..bfa7b7c2f697f9454c5381e31ee78040be5b229e --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -0,0 +1,403 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import stk +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + + +# Calling tensor.t() calls tensor.transpose(0, 1) which calls +# torch.as_strided(...). Circumvent this chain to avoid an overhead +# this adds. +def transpose_view(x): + return torch.as_strided( + x, + (x.shape[1], x.shape[0]), + (x.stride()[1], x.stride()[0]), + ) + + +_MATMUL_TESTS = ( + (64 * 1024, 512, 2048, 64), + (32 * 1024, 768, 3072, 64), + (8 * 1024, 1024, 4096, 64), + (4 * 2048, 4096, 4 * 4096, 4), +) + + +def log_benchmark(name, arguments, time, std, flops): + benchmark_util.log_benchmark(name, arguments, time, std) + print('flops = {:.2f}B'.format(flops / 1e9)) + print('throughput = {:.2f}T'.format(flops / 1e9 / time)) + print('=' * 60) + + +class MatmulBenchmark(parameterized.TestCase): + + def build_sparse_matrix(self, x, padded_bins, fhs, ne): + blocking = 128 + padded_tokens, _ = x.size() + assert padded_tokens % blocking == 0 + assert fhs % blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // blocking + blocks_per_row = fhs // blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. 
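+        # Per the comment above, ops.topology derives the block column index
+        # of each nonzero block from the padded per-expert bin boundaries.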
+ column_indices = ops.topology( + padded_bins, + blocking, + block_rows, + blocks_per_row, + ) + data = torch.empty( + column_indices.numel(), + blocking, + blocking, + dtype=torch.float16, + device=x.device, + ) + shape = (padded_tokens, fhs * ne) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix(shape, data, row_indices, column_indices, offsets) + + def build_input_matrix(self, sl, hs, ne): + x = torch.randn((sl, hs)).cuda().half() + + # Assign tokens to experts uniformly. + top_expert = torch.arange(0, sl).cuda().int() % ne + + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) + return out, padded_bins + + def build_weight_matrix(self, ne, hs, fhs): + return torch.randn((hs, ne * fhs)).cuda().half() + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(x, w, topo) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd::SDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(topo, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradX::DSD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + topo = topo.t() + + def benchmark(): + return stk.ops.dsd(topo, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DSD::NN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): 
+ x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(out, w, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::SDD::NT', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + x = x.t() + + def benchmark(): + return stk.ops.dsd(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + + w = w.transpose(1, 2).contiguous() + w = w.transpose(1, 2) + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd:DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + w = w.transpose(1, 2).contiguous() + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradX:DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + out = out.transpose(1, 2) + + def benchmark(): + return torch.bmm(out, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradW:DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) 
+ + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + w = torch.transpose(w, 1, 2) + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + x = torch.transpose(x, 1, 2) + + def benchmark(): + return torch.bmm(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..f272a7768dc6468e81bd2cd25294ca16a6826c08 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py @@ -0,0 +1,55 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_gather kernel. +class PaddedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins, padded_bins) + ctx.top_k = top_k + return kernels.padded_gather( + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins, padded_bins = ctx.saved_tensors + out = kernels.padded_scatter( + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) + return out, None, None, None, None, None + + +padded_gather = PaddedGatherOp.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff81dd9456a04258346266e9e85871bda56c65b --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py @@ -0,0 +1,98 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_scatter kernel. 
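+# The backward pass below re-gathers the incoming gradient for dgrad and,
+# when the expert weights require grad, reduces per-entry weight gradients
+# with kernels.padded_scatter_wgrad using the tensors saved in forward.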
+class PaddedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward( + indices, + bin_ids, + weights, + bins, + padded_bins, + *maybe_x, + ) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.padded_gather( + grad, + indices, + bin_ids, + weights, + bins, + padded_bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.padded_scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + padded_bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None, None + + +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): + return PaddedScatterOp.apply( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..81dde4e4b4c0f550c6027390e8de285fa4d842f7 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -0,0 +1,66 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PADDED_SCATTER_BENCHMARK = ( + # dMoE-Medium, 8-way EMP. + (1024 * 16, 1024, 8, 4), + # dMoE-Medium, post-all-to-all. + (1024 * 16 * 4, 1024, 8, 1), +) + + +class PaddedScatterTest(parameterized.TestCase): + + @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) + def testPaddedScatter(self, sl, hs, ne, top_k): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + # Sample weights for the scatter reduce. + weights = torch.rand((sl * top_k,)).cuda().half() + + # Gather the data to prepare for backwards. 
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + def benchmark(): + return ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + time, std = benchmark_util.benchmark_function(benchmark) + benchmark_util.log_benchmark( + 'Padded Scatter', + { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + 'top_k': top_k, + }, + time, + std, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..837f07e2858d5c65c14e3f262617e7d2985ce901 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -0,0 +1,149 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PERMUTE_TESTS = ( + (16384, 768, 2), + (16384, 768, 4), + (16384, 768, 8), + (16384, 768, 16), + (16384, 768, 32), + (16384, 768, 64), + (16384, 768, 128), + (16384 * 8, 768, 2), + (16384 * 8, 768, 4), + (16384 * 8, 768, 8), + (16384 * 8, 768, 16), + (16384 * 8, 768, 32), + (16384 * 8, 768, 64), + (16384 * 8, 768, 128), +) + + +class PermuteBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedGather(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.binned_gather(x, indices, bins, ec) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedScatter(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.binned_gather(x, indices, bins, ec) + + def benchmark(): + return ops.binned_scatter(x, indices, bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedGather(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. 
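+        # After the random assignment below, the benchmark builds the routing
+        # metadata consumed by the padded ops: sort the expert ids, histogram
+        # them into per-expert token counts, round each count up to the
+        # 128-row block size, and take inclusive cumulative sums to get the
+        # padded and unpadded bin boundaries.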
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedScatter(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testCopy(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + # ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + y = x.clone() + + def benchmark(): + return y.copy_(x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def repeat(x: torch.Tensor, tiling: torch.Size): + if all((t == 1 for t in tiling)): + return x + return x.repeat(*tiling) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..e570e22093339d1b71dc830edf656897c67f40e7 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py @@ -0,0 +1,36 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. 
+try:
+    from megablocks._ops import ops  # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+        ctx.save_for_backward(bins)
+        out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+        ops.replicate_forward(x, bins, out)
+        return out
+
+    @staticmethod
+    def backward(ctx: Any, grad: torch.Tensor):
+        bins, = ctx.saved_tensors
+        out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+        ops.replicate_backward(grad, bins, out)
+        return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+    assert isinstance(value, int)
+    assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+    # do this in a custom kernel. We only expect
+    # to use this on arrays of less than 1k elements.
+    return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5aaafc402a651a47595c25dbf71e382903e5022
--- /dev/null
+++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from stk.backend.autocast import custom_bwd, custom_fwd
+
+from megablocks.backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins = saved_tensors[:4] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.gather( + grad, + indices, + bin_ids, + weights, + bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None + + +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +) -> Optional[torch.Tensor]: + return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py new file mode 100644 index 0000000000000000000000000000000000000000..a64a4553323d8de71a1d5f1821bac50249c68197 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional, Tuple + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + +_BITS_FOR_DTYPE = { + torch.int16: 16, + torch.int32: 32, + torch.int64: 64, +} + + +# Autograd wrapper for sort kernel. +# NOTE: Does not support gradients. 
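+# The op writes the sorted keys and the accompanying permutation ("iota")
+# into preallocated output tensors; when no end_bit is given, the full bit
+# width of the input dtype from _BITS_FOR_DTYPE is used.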
+class SortOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + if end_bit is None: + end_bit = _BITS_FOR_DTYPE[x.dtype] + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + ops.sort(x, end_bit, x_out, iota_out) + return (x_out, iota_out) + + +sort = SortOp.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -0,0 +1,85 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_SORT_TESTS = ( + (16384, torch.int32, None), + (16384, torch.int32, 2), + (16384, torch.int32, 128), +) + +_BASELINE_SORT_TESTS = ((16384,),) + + +def numpy_dtype(dtype): + types = { + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + } + return types[dtype] + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. + fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class SortBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_SORT_TESTS) + def testSort(self, n, dtype, max_val): + if max_val is None: + max_val = np.iinfo(numpy_dtype(dtype)).max + end_bit = int(np.ceil(np.log2(max_val))) + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_BASELINE_SORT_TESTS) + def testTorchSort(self, n): + x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) + arguments = { + 'n': n, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py new file mode 100644 index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py @@ -0,0 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import torch + + +def sum(x: torch.Tensor, dim: int = 0): + if x.shape[dim] == 1: + return x.squeeze(dim=dim) + return x.sum(dim=dim) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py new file mode 100644 index 
0000000000000000000000000000000000000000..b4a693a82a36c052e51afb4a02077e7849a4abf1 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py @@ -0,0 +1,45 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for topology kernel. +# NOTE: Does not support gradients. +class TopologyOp(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, + ): + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + ops.indices( + padded_bins, + block_size, + output_block_rows, + output_block_columns, + out, + ) + return out + + +topology = TopologyOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d8a2495c4baab89e5f90d84fa92f406a29ea10 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py @@ -0,0 +1,195 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from ._ops import ops + +from megablocks.layers.arguments import Arguments +from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE +from megablocks.layers.glu import SparseGLU +from megablocks.layers.mlp import MLP, SparseMLP +from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +# This section contains the direct kernel exports (not inlcuded in the original code) +def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute exclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.exclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute inclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.inclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: + """ + Compute histogram of input tensor values. + + Args: + x: Input tensor + num_bins: Number of histogram bins + + Returns: + Histogram tensor with counts for each bin + """ + return ops.histogram(x, num_bins) + + +def indices( + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, +) -> torch.Tensor: + """ + Construct indices from padded bins for sparse operations. 
+ + Args: + padded_bins: Tensor containing bin boundaries + block_size: Size of each block + output_block_rows: Number of rows in output blocks + output_block_columns: Number of columns in output blocks + + Returns: + Tensor containing constructed indices + """ + return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) + + +def replicate_forward( + x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Forward pass of replicate operation - replicate values according to bin sizes. + + Args: + x: Input tensor with values to replicate + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_forward(x, bins, out) + + +def replicate_backward( + grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Backward pass of replicate operation - reduce gradients back to bins. + + Args: + grad: Gradient tensor to reduce + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_backward(grad, bins, out) + + +def sort( + x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor +) -> torch.Tensor: + """ + Radix sort with index tracking. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + x_out: Output tensor for sorted values + iota_out: Output tensor for sorted indices + + Returns: + The sorted values tensor + """ + return ops.sort(x, end_bit, x_out, iota_out) + + +# Convenience functions for common use cases +def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: + """ + Compute cumulative sum with automatic output allocation. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum (default: last dimension) + exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum + + Returns: + New tensor containing the cumulative sum + """ + out = torch.empty_like(x) + if exclusive: + return exclusive_cumsum(x, dim, out) + else: + return inclusive_cumsum(x, dim, out) + + +def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sort tensor and return both sorted values and indices. 
+ + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + + Returns: + Tuple of (sorted_values, sorted_indices) + """ + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + sort(x, end_bit, x_out, iota_out) + return x_out, iota_out + + +# Export public API +__all__ = [ + # Direct kernel exports + "exclusive_cumsum", + "inclusive_cumsum", + "histogram", + "indices", + "replicate_forward", + "replicate_backward", + "sort", + "cumsum", + "argsort", + # Original exports + "Arguments", + "ParallelDroplessMLP", + "dMoE", + "SparseGLU", + "MLP", + "SparseMLP", + "MoE", + "ParallelMLP", + "get_load_balancing_loss", +] diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..519759fd66838a696df3771c278cb787e9c11dfe --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:12cf047c4bcb5f368f490ada3dafcecd1242a443ff5ded94c09d62548922098c +size 11795992 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd42432c910375fb1713b3bf29145e48b556498 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _megablocks_359242d +ops = torch.ops._megablocks_359242d + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_megablocks_359242d::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_version.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_version.py @@ -0,0 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +"""The MegaBlocks Version.""" + +__version__ = '0.11.0.dev0' diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/backend/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/backend/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/backend/kernels.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/backend/kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..b584ceede926ca30abef2dec581cb3ff329e8e16 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/backend/kernels.py @@ -0,0 +1,543 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') + + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') + + +def assert_equal(a, b): + if a != b: + raise ValueError(f'Expected 
dimensions to be equal but got {a} and {b}.',) + + +# a: (tokens, hidden_size), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy( + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Our index into array 'a'. + index_a = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'b'. + index_b = offset_in_bin + if bin_idx > 0: + index_b += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: Because of the padding, the output size is dynamic. + # We load the final padded bin bound to get the output rows. + output_rows = padded_bins[-1].cpu().item() + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def gather(x, indices, bin_ids, weights, bins, top_k): + # Validate the input shapes. 
+ assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: There is no padding so the output rows equals the + # input rows multiplied by top_k. + output_rows = x.shape[0] * top_k + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + out, + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + + +def scatter(x, indices, bin_ids, weights, bins, top_k): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +# x: (tokens, top_k, hidden_size), real +# grad: (tokens, hidden_size), real. +# wgrad: (tokens, top_k), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy_wgrad( + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Our index into 'tokens * top_k'. + index_out = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'x'. + index_x = offset_in_bin + if bin_idx > 0: + index_x += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. 
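+    # 'wgrad' holds one scalar per (token, top_k) routing entry, so it is
+    # offset by the flat routing index. 'grad' is laid out per token, so the
+    # routing index is divided by TOP_K, while 'x' is indexed by the padded,
+    # expert-sorted row computed above.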
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + _padded_copy_wgrad[(indices.shape[0],)]( + x, + grad, + out, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + TOP_K=top_k, + ) + return out + + +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy( + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_b = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_a = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + # + # NOTE: We need to zero the output in both directions. 
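+    # Entries past an expert's token count are skipped by the early return
+    # above and never written, which is why binned_gather and binned_scatter
+    # allocate their outputs with torch.zeros rather than torch.empty.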
+ iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + num_experts = bins.shape[0] + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + + _binned_copy[(num_experts, expert_capacity)]( + x, + out, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def binned_scatter(x, indices, weights, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( + out, + x, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=hidden_size, + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy_wgrad( + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_x = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_out = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. 
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def binned_scatter_wgrad(x, grad, indices, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) + _binned_copy_wgrad[(num_experts, expert_capacity)]( + x, + grad, + out, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS=hidden_size, + TOP_K=top_k, + ) + return out diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/bak.__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/bak.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5217959caf74527e3bf7f80db6f93be21c016963 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/bak.__init__.py @@ -0,0 +1,23 @@ +from megablocks_moe.megablocks import ( + MoE, + dMoE, + get_load_balancing_loss, + ParallelMLP, + ParallelDroplessMLP, + SparseMLP, + MLP, + SparseGLU, + Arguments, +) + +__all__ = [ + "MoE", + "dMoE", + "get_load_balancing_loss", + "ParallelMLP", + "ParallelDroplessMLP", + "SparseMLP", + "MLP", + "SparseGLU", + "Arguments", +] diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/benchmark_util.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/benchmark_util.py new file mode 100644 index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/benchmark_util.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + + +def log_benchmark(name, arguments, time, std): + print('=' * 60) + print(f'{name} Benchmark') + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) + print('=' * 60) + + +def benchmark_function(fn, iterations=100, warmup=10): + # Warmup iterations. 
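+    # Kernel launches are asynchronous, so the warmup calls absorb one-time
+    # costs such as compilation and autotuning, and each timed iteration below
+    # records CUDA events and synchronizes before reading the elapsed time.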
+ for _ in range(warmup): + fn() + + times = [] + for i in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + fn() + end.record() + + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + return np.mean(times), np.std(times) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6d3f977f360fd0ad5800c3b5da9ce57be794b9b8 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm_util.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import warnings + +_grouped_gemm_is_available: bool = False +try: + import grouped_gemm + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn('Grouped GEMM not available.') + + +def grouped_gemm_is_available(): + return _grouped_gemm_is_available + + +def assert_grouped_gemm_is_available(): + msg = ( + 'Grouped GEMM not available. Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg + + +backend = grouped_gemm.backend if grouped_gemm_is_available() else None +ops = grouped_gemm.ops if grouped_gemm_is_available() else None diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..849b023b1a0f765fb1e26addf4670dc6db785a52 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# from megablocks.layers.dmoe import dMoE +from megablocks.layers.moe import MoE + +__all__ = [ + 'MoE', + # 'dMoE', +] diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/activation_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..a31770ba179fed06abf2da10102ccaeed1d3ee4e --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/activation_fn.py @@ -0,0 +1,33 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Union + +import torch +from stk import Matrix + + +def act_fn( + x: Matrix, + function: Callable, + return_grad_fn: bool = False, + **kwargs, +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) + with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): + if return_grad_fn: + x.data.requires_grad = True + out = function(x.data, **kwargs) + y = Matrix( + x.size(), + out, + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + if return_grad_fn: + return y, out.backward + return y diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/all_to_all.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/all_to_all.py @@ -0,0 +1,54 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + + +class 
AllToAllOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) + + ctx.input_shape = x.shape + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + ctx.group = group + handle = dist.all_to_all_single( + out, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + return out, handle + + @staticmethod + def backward(ctx, grad, _): + if ctx.needs_input_grad[0]: + out = torch.empty( + ctx.input_shape, + device=grad.device, + dtype=grad.dtype, + ) + dist.all_to_all_single( + out, + grad, + output_split_sizes=ctx.input_split_sizes, + input_split_sizes=ctx.output_split_sizes, + group=ctx.group, + ) + return out, None, None, None, None + return None, None, None, None, None + + +def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): + return AllToAllOp.apply( + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..3962c771c90012535aab058443f89c541a2e9236 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/arguments.py @@ -0,0 +1,100 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from functools import partial +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +import megablocks.grouped_gemm_util as grouped_gemm + +# Type annotation for in-place Tensor initialization function. +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] + +_ALLOWED_BITWIDTHS = (-1, 4, 8) + +DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') + + +@dataclasses.dataclass +class Arguments: + # Model arguments. + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + num_layers: int = 1 + bias: bool = True + return_bias: bool = True + activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN + + # MoE arguments. + moe_num_experts: int = 1 + moe_top_k: int = 1 + moe_capacity_factor: int = 1 + moe_normalize_expert_weights: Optional[Union[int, float]] = None + moe_loss_weight: float = 0.1 + moe_jitter_eps: Optional[float] = None + moe_lbl_in_fp32: bool = False + + # Parallelism arguments. + moe_expert_model_parallelism: bool = False + expert_parallel_group: Optional[dist.ProcessGroup] = None + pipeline_model_parallel_size: int = 1 + num_layers_per_virtual_pipeline_stage: Optional[int] = None + + # Compute arguments. + memory_optimized_mlp: bool = False + mlp_type: str = 'mlp' + mlp_impl: str = 'sparse' + + # Initialization arguments. + fp16: bool = True + bf16: bool = False + device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) + init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) + output_layer_init_method: InitFn = init_method + + # Benchmarking arguments. 
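+    # NOTE(editor): uniform_expert_assignment is presumably a benchmarking
+    # switch that replaces the learned routing with a fixed, uniform
+    # token-to-expert assignment; it is consumed by the router/MoE layers
+    # rather than in this dataclass.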
+ uniform_expert_assignment: bool = False + + # shared expert arguments + shared_expert: bool = False # enable using shared expert + fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) + fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers + remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored + shared_expert_hidden_size: Optional[ + int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size + shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) + + # Router Z-loss arguments + moe_zloss_weight: float = 0 # 1e-3 is a reasonable value + moe_zloss_in_fp32: bool = False + + def __post_init__(self): + # Sparse MLP is not supported with triton >=3.2.0 + # TODO: Remove this once sparse is supported with triton >=3.2.0 + if self.__getattribute__('mlp_impl') == 'sparse': + try: + import triton + if triton.__version__ >= '3.2.0': + raise ValueError( + 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.', + ) + except ImportError: + raise ImportError('Triton is required for sparse MLP implementation') + + if self.__getattribute__('mlp_impl') == 'grouped': + grouped_gemm.assert_grouped_gemm_is_available() + + if self.shared_expert_hidden_size is None: + self.shared_expert_hidden_size = self.ffn_hidden_size + + +def from_megatron(megatron_args: Any): + args = Arguments() + for field in dataclasses.fields(args): + if hasattr(megatron_args, field.name): + setattr(args, field.name, getattr(megatron_args, field.name)) + return args diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/common.py new file mode 100644 index 0000000000000000000000000000000000000000..ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/common.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from megablocks.layers.arguments import Arguments + + +def dtype(args: Arguments): + if args.fp16: + return torch.float16 + elif args.bf16: + return torch.bfloat16 + return None + + +def cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == 'cuda': + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == 'cpu': + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d765bd04387a29bddd2789fae04821635b555e82 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -0,0 +1,42 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +MlpType = Union[mlp.SparseMLP, glu.SparseGLU] + +_REGISTRY = { + 'mlp': { + 'grouped': mlp.GroupedMLP, + 'sparse': mlp.SparseMLP, + }, + 'glu': { + 'grouped': glu.GroupedGLU, + 'sparse': glu.SparseGLU, + }, +} + + +def get(args: Arguments) -> MlpType: + 
"""Returns an MLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + MLP instance. This only contains MLPs for use in dMoEs + (ie. only for the dropless versions of MoEs). + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated MLP constructed using the input args. + """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + if args.mlp_impl not in _REGISTRY[args.mlp_type]: + raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) + + return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..205727ff4d63f9e8dc9648acaac99a97f3394d6f --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmoe.py @@ -0,0 +1,327 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import stk.ops +import torch +from stk import Matrix + +import megablocks.ops as ops +# from megablocks.ops import ops +from megablocks.layers import common, dmlp_registry, moe, mpu +from megablocks.layers.arguments import Arguments + + +def promote_scalar(x): + return x.view(1) if not len(x.size()) else x + + +class ParallelDroplessMLP(moe.ParallelMLP): + + def __init__(self, args: Arguments): + super(ParallelDroplessMLP, self).__init__(args) + self.hidden_size = args.hidden_size + self.ffn_hidden_size = mpu.features_per_rank(args) + self.blocking = 128 + self.mlp = dmlp_registry.get(args) + + # Calculate the number of bits needed to represent the column indices + # in the intermediate sparse matrix. + max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) + self.transpose_sort_end_bit = max( + int(np.ceil(np.log2(max_column_index))), + 1, + ) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + # + # NOTE: Our sort operation uses the same width indices as the input values. + # To avoid overflow when we have large activation matrices we cast to + # 32-bit before sorting. + _, gather_indices = ops.sort( + column_indices.int(), + self.transpose_sort_end_bit, + ) + + # There are a constant number of blocks in every row of the sparse matrix. + # A blocks offset is: + # + # row_index * blocks_per_row + column_index % blocks_per_row + # + # Once we have the block offsets ordered for transposition we can divide + # by blocks_per_row to get the transposed column indices. + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + if nnz_per_column.dim() == 0: + # This addresses an edge case when ffn_hidden_size is equal to self.blocking. 
+ nnz_per_column = nnz_per_column.unsqueeze(0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + if self.ffn_hidden_size % self.blocking != 0: + raise ValueError( + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + + f'the block size {self.blocking}. Please update your configuration.', + ) + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // self.blocking + blocks_per_row = self.ffn_hidden_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology( + padded_bins, + self.blocking, + block_rows, + blocks_per_row, + ) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=common.dtype(self.args), + device='meta', + ) + shape = ( + padded_tokens, + self.ffn_hidden_size * mpu.experts_per_rank(self.args), + ) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, + row_indices, + column_indices, + offsets, + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_experts = top_experts.int() + bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Caculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = promote_scalar(bins) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def sparse_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather( + x, + indices, + bin_ids, + bins, + padded_bins, + self.top_k, + ) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. 
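+ # padded_scatter inverts the padded_gather above: it drops the block padding and
+ # combines each token's top_k expert outputs using expert_weights.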
+ x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.top_k, + ) + return x, tokens_per_expert + + # For use in the base-class parallel_forward_once. + def sparse_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + return ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + top_k, + ) + + def grouped_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + out = self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + -1, # unused + self.args.moe_top_k, + ) + return out, tokens_per_expert + + def grouped_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Perform the expert computation. + x = self.mlp(x, tokens_per_expert) + + # Un-route the data for the MoE output. 
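+ # scatter is the inverse of the gather above and also applies expert_weights when
+ # merging the top_k expert outputs for each token.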
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + def forward_once(self, x, expert_weights, top_experts): + if self.args.mlp_impl == 'sparse': + return self.sparse_forward_once(x, expert_weights, top_experts) + else: + return self.grouped_forward_once(x, expert_weights, top_experts) + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ): + if self.args.mlp_impl == 'sparse': + return self.sparse_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + else: + return self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + + +class dMoE(moe.MoE): + + def _init_experts_mlp(self, args: Arguments): + return ParallelDroplessMLP(args) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/gelu.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..40b601d4a8eb59e80e1090f31f4172b8f7fb7549 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/gelu.py @@ -0,0 +1,43 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk +import torch +import torch.nn.functional as F + + +@torch.jit.script +def _gelu_backward_inplace(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) + return g.mul_(ff) + + +def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): + # NOTE: The two sparse matrices must have the same topology. + if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): + return stk.Matrix( + x.size(), + _gelu_backward_inplace(grad.data, x.data), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + return _gelu_backward_inplace(grad, x) + + +def gelu(x: stk.Matrix): + assert isinstance(x, stk.Matrix) + return stk.Matrix( + x.size(), + F.gelu(x.data, approximate='tanh'), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/glu.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe0c915c307e7f7cade3ea3ff679399635fcd81 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/glu.py @@ -0,0 +1,223 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk.ops +import torch + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import Arguments +from megablocks.layers.mlp import ( + SharedMLP, + SparseMLP, + create_dmoe_expert_weights, + resolve_dtensor, +) + + +class SparseGLU(SparseMLP): + + def __init__(self, args: Arguments): + super().__init__(args) + self.v1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + with torch.no_grad(): + self.v1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + + mpu.set_expert_model_parallel_attributes( + 
self.v1, + self._should_set_parallelism_attribute, + ) + + def forward(self, x, topo): + if self.args.memory_optimized_mlp: + raise NotImplementedError( + 'Memory optimized implementation not yet supported with GLU with sparse kernels.', + ) + + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Compute the GLU. + x1 = stk.ops.sdd(x, w1.t(), topo) + x2 = stk.ops.sdd(x, v1.t(), topo) + + activation_fn_out = act_fn(x1, self.args.activation_fn) + x1 = stk.ops.mul(activation_fn_out, x2) + + return stk.ops.dsd(x1, w2) + + +class MemoryOptimizedGroupedGLU(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + v1 = v1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) + + # GeLU. + activation_fn_out = activation_fn(sdd_out) * v1_out + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, v1, w2 = saved_tensors[:3] + batch_sizes = saved_tensors[3] + x = saved_tensors[4] + sdd_out, v1_out = saved_tensors[5:7] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + v1_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * v1_out + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + dv1_out = v1_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dv1. 
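+ # Same grouped GEMM pattern as dw1 above, driven by the gradient of the gating
+ # branch (dv1_out).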
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + dx = ddsd_out + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) + dx += gg.backend.gmm(dv1_out, v1, batch_sizes) + return dx, dw1, dv1, dw2, None, None + + +memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply + + +class GroupedGLU(SparseGLU): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, v1, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.v1), + self.scale_grad(self.w2), + ) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Re-shape the weights for the grouped GEMMs. + ne = mpu.experts_per_rank(self.args) + w1 = w1.view(ne, -1, self.args.hidden_size) + v1 = v1.view(ne, -1, self.args.hidden_size) + w2 = w2.view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_glu( + x, + w1, + v1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) + x1 = self.args.activation_fn(x1) * x2 + return gg.ops.gmm(x1, w2, batch_sizes) + + +class SharedGLU(SharedMLP): + """GPU for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class + """ + + def __init__(self, args: Arguments): + super().__init__(args) + self.gate_proj = args.fc_cls( + args.hidden_size, + self.args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/memory_test.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/memory_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4acbd94f212ce906ace9de78dd3ffc6afa03f97e --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/memory_test.py @@ -0,0 +1,102 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import gc + +import torch +import torch.distributed as dist + +from megablocks.layers import arguments, dmoe + +_TESTS = ((8, 2048, 4096, 4096, 32, 4),) + + +def get_tensors(): + ptrs = set() + out = [] + for obj in gc.get_objects(): + if torch.is_tensor(obj): + if not obj.is_contiguous() or obj.data_ptr() in ptrs: + continue + out.append(obj) + ptrs.add(obj.data_ptr()) + return out + + +def test_memory( + group, + batch_size, + sequence_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, +): + args = arguments.Arguments( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=num_experts, + moe_top_k=top_k, + moe_expert_model_parallelism=True, + expert_parallel_group=group, + fp16=False, + bf16=True, + device=torch.cuda.current_device(), + ) + layer = dmoe.dMoE(args).cuda() + + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16).requires_grad_(True) + torch.cuda.empty_cache() + + # Run forward + backward. + # with torch.autograd.detect_anomaly(): + out, _ = layer(x) + out.mean().backward() + + # Report peak memory. 
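+ # max_memory_allocated/max_memory_reserved are high-water marks for the process,
+ # so they cover the forward and backward passes above.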
+ mem = torch.cuda.max_memory_allocated() + print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) + print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) + + # Calculate weight and gradient memory usage. + weight_memory = 2 * ( + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() + ) + + def grad_numel(x): + if x.grad is not None: + return x.grad.numel() + return 0 + + grad_memory = 2 * ( + grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + ) + weight_memory += grad_memory + + print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) + print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) + + # Manually calculate GPU memory usage from the garbage + # collector. + gc.collect() + total = 0 + tensors = get_tensors() + tensors = sorted(tensors, key=lambda x: -x.numel()) + for i, t in enumerate(tensors): + total += t.numel() + print(f'{i}: {t.shape}, {t.numel() * 2}') + del tensors + + print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _TESTS: + test_memory(group, *args) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6f4d82441e3a9c185db5bfdf686d53790dde26 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mlp.py @@ -0,0 +1,574 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import stk +import stk.backend.triton_kernels +import stk.ops +import torch +from packaging import version + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, gelu, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + + +class ScaleGradient(torch.autograd.Function): + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx: Any, x: torch.Tensor, scale: float): + ctx.scale = scale + return x + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: torch.Tensor, grad: torch.Tensor): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +def resolve_dtensor(weight: torch.Tensor): + if version.parse(torch.__version__) >= version.parse('2.0.0'): + from torch.distributed._tensor import DTensor + if isinstance(weight, DTensor): + return weight.to_local() + return weight + + +def create_moe_expert_weights( + args: Arguments, + num_experts: int, + ffn_hidden_size: int, + hidden_size: int, + init_method: InitFn, +): + # Create the entire weight matrix such that the sampled weights will + # not vary between data parallelism and expert model parallelism for + # the same random seed. + master_weights = torch.empty( + num_experts, + ffn_hidden_size, + hidden_size, + device=args.device, + dtype=common.dtype(args), + ) + init_method(master_weights) + + if not args.moe_expert_model_parallelism: + return master_weights + + # Calculate the amount of sharding in each dimension. 
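+ # expert_sharding_degree splits whole experts across ranks; hidden_sharding_degree
+ # additionally splits each expert's ffn_hidden_size rows when the expert-parallel
+ # world size exceeds the number of experts.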
+ expert_sharding_degree = mpu.expert_sharding_degree(args) + hidden_sharding_degree = mpu.hidden_sharding_degree(args) + + # Calculate the experts per rank. + # + # NOTE: We assign ranks to be expert parallel before going + # tensor parallel. + rank = mpu.get_expert_parallel_rank(args) + expert_rank = rank % expert_sharding_degree + num_experts_per_rank = num_experts // expert_sharding_degree + start_expert = expert_rank * num_experts_per_rank + end_expert = (expert_rank + 1) * num_experts_per_rank + + # Calculate the rows per rank. + row_rank = rank // expert_sharding_degree + num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree + start_row = row_rank * num_rows_per_rank + end_row = (row_rank + 1) * num_rows_per_rank + + # Slice the weight matrix to get the chunk for this rank. + with torch.no_grad(): + weights = master_weights[start_expert:end_expert, start_row:end_row] + return weights + + +class MLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) + experts_per_rank = mpu.experts_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + args.hidden_size, + mpu.features_per_rank(args), + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + mpu.features_per_rank(args), + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + mpu.set_expert_model_parallel_attributes( + self.w1, + args.moe_expert_model_parallelism, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + args.moe_expert_model_parallelism, + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. 
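+ # Initialization runs under no_grad so copying the sliced master weights is not
+ # tracked by autograd.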
+ with torch.no_grad(): + w1 = create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) + self.w1.copy_(w1.transpose(1, 2).contiguous()) + self.w2.copy_( + create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + x = torch.bmm(x, w1) + x = self.args.activation_fn(x) + return torch.bmm(x, w2) + + +def create_dmoe_expert_weights( + args: Arguments, + num_experts: int, + rows: int, + columns: int, + init_method: InitFn, +): + weights = create_moe_expert_weights( + args, + num_experts, + rows, + columns, + init_method, + ) + return weights.view([-1, columns]) + + +class MemoryOptimizedMLP(torch.autograd.Function): + """Sparse MLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, topo, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + topo_tensors = ( + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) + + # Layer 0: x @ w1.t(). + sdd_out = stk.ops.sdd(x, w1.t(), topo) + + # GeLU. + activation_fn_out = act_fn(sdd_out, activation_fn) + + # Layer 1: x @ w2. + dsd_out = stk.ops.dsd(activation_fn_out, w2) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.shape = topo.shape + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.data.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + topo_tensors = saved_tensors[2:8] + x = saved_tensors[8] + sdd_out_data = saved_tensors[9] + + # rematerialize activation function output + activation_fn = ctx.activation_fn + sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) + activation_fn_out, activation_grad_fn = act_fn( + sdd_out, + activation_fn, + return_grad_fn=True, + ) + + # Compute dw2 with recomputed activation_fn output. + dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. 
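+ # The triton sdd kernel below writes the activation gradient directly into the
+ # recomputed activation_fn_out buffer, so no extra sparse matrix is allocated.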
+ dactivation_fn_out = activation_fn_out + stk.backend.triton_kernels.sdd( + ddsd_out, + w2.t(), + dactivation_fn_out.shape, + dactivation_fn_out.data, + dactivation_fn_out.offsets, + dactivation_fn_out.row_indices, + dactivation_fn_out.column_indices, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out.data) + dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) + + # Compute dw1. + dw1 = stk.ops.dsd(dsdd_out.t(), x) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + stk.backend.triton_kernels.dsd( + dsdd_out.shape, + dsdd_out.data, + dsdd_out.offsets, + dsdd_out.row_indices, + dsdd_out.column_indices, + dsdd_out.offsets_t, + dsdd_out.column_indices_t, + dsdd_out.block_offsets_t, + False, + w1, + ddsd_out, + ) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_mlp = MemoryOptimizedMLP.apply + + +class SparseMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. + with torch.no_grad(): + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism + mpu.set_expert_model_parallel_attributes( + self.w1, + self._should_set_parallelism_attribute, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + self._should_set_parallelism_attribute, + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x, topo): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + if self.args.memory_optimized_mlp: + return memory_optimized_mlp( + x, + w1, + w2, + topo, + self.args.activation_fn, + ) + + # Compute the MLP. 
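+ # stk.ops.sdd: dense x dense -> block-sparse output restricted to the topology in
+ # topo; stk.ops.dsd: block-sparse x dense -> dense output.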
+ x = stk.ops.sdd(x, w1.t(), topo) + activation_fn_out = act_fn(x, self.args.activation_fn) + return stk.ops.dsd(activation_fn_out, w2) + + +class MemoryOptimizedGroupedMLP(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + + # activation_fn + activation_fn_out = activation_fn(sdd_out) + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: Any, ddsd_out: torch.Tensor): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + batch_sizes = saved_tensors[2] + x = saved_tensors[3] + sdd_out = saved_tensors[4] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply + + +class GroupedMLP(SparseMLP): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) + + # Re-shape the weights for the grouped GEMMs. 
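+ # The flat [num_rows_per_rank, hidden_size] parameters are viewed as
+ # [experts_per_rank, rows_per_expert, hidden_size] so each expert gets its own GEMM.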
+ ne = mpu.experts_per_rank(self.args) + w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) + w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_mlp( + x, + w1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x = self.args.activation_fn(x) + return gg.ops.gmm(x, w2, batch_sizes) + + +class SharedMLP(torch.nn.Module): + """MLP for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class + """ + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self.fc_kwargs: dict[str, Any] = { + 'bias': args.bias, + 'device': args.device, + } + self.fc_kwargs.update(args.fc_kwargs) + + self.up_proj = args.fc_cls( + args.hidden_size, + args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + self.act = args.activation_fn + self.down_proj = args.fc_cls( + args.shared_expert_hidden_size, + args.hidden_size, + **self.fc_kwargs, + ) + self.down_proj._is_residual = True # a flag for llm-foundry init + + def add_experts_sharedexpert( + self, + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + ) -> torch.Tensor: + # Helper function to add expert output to shared expert output + # with optional weighted sum. + if self.args.shared_expert_weighted_sum: + # enable using weighted sum for shared expert output + # wieghted by number of experts used + t_experts = self.args.moe_top_k + 1 + sh_mlp_out = shared_expert_out / t_experts + return sh_mlp_out.add( + expert_out, + alpha=(self.args.moe_top_k / t_experts), + ) + + return shared_expert_out + expert_out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/moe.py @@ -0,0 +1,475 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist + +import megablocks.ops as ops +from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +from megablocks.layers.all_to_all import all_to_all +from megablocks.layers.arguments import Arguments + +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss(args: Arguments): + if args.moe_loss_weight == 0: + return 0.0 + + # tokens_per_expert[i].shape = (num_experts) + # expert_scores[i].shape = (tokens, num_experts) + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} token_per_experts ' 
+ f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} expert_scores ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) + + tokens = expert_scores[0].shape[0] + assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = (args.moe_num_experts * args.moe_loss_weight) + scale_denominator = (args.num_layers * tokens * args.moe_top_k) + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +# NOTE: This class defines MoE expert computation, including expert model parallel +# communication. When using FSDP on top of MegaBlocks this is the module that should +# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model +# parallel all2all. +class ParallelMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super(ParallelMLP, self).__init__() + self.args = args + + # Calculate the number of experts in total and the number of experts + # owned by this rank. + # world_size = mpu.get_expert_parallel_world_size(args) + self.num_experts = args.moe_num_experts + self.top_k = self.args.moe_top_k + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + # Expert MLP. + self.mlp = mlp.MLP(args) + + self.bias: Optional[torch.Tensor] + if self.args.bias: + # Note that the output bias is not parallelized with expert + # model parallelism. + self.bias = torch.nn.Parameter( + torch.empty( + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + # Select the forward function for the operating mode. 
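+ # parallel_forward_once adds the all_to_all token exchange across expert-parallel
+ # ranks; forward_once keeps all expert computation on the local device.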
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) + + def expert_capacity(self, tokens: int) -> int: + world_size = mpu.get_expert_parallel_world_size(self.args) + tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) + return int(self.args.moe_capacity_factor * tokens_per_expert) + + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): + """Calculate the load balancing loss contribution.""" + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == self.num_experts + assert len(tokens_per_expert.size()) == 1 + num_experts, = tokens_per_expert.size() + assert num_experts == self.num_experts + scale = self.num_experts / (tokens * self.top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # + # TODO(tgale): Is it worth doing this conversion to 32-bit + # prior? Could we place the `torch.max` operation to return + # 32-bit expert indices? + top_expert = top_expert.int() + output = ops.sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + # + # TODO(tgale): Does the sorted data produce a more favorable + # data distribution for histogram? Or is the op parallelism + # worth more? + tokens_per_expert = ops.histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None + bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + + return indices, bin_ids, bins, tokens_per_expert + + def permute_and_compute( + self, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, + ): + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output + + # Perform the expert computation. Note that we don't + # use biases for these linear operations. + x = self.mlp(x) + + # Un-route the data for the MoE output. + return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. 
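+ # A capacity of zero (e.g. when moe_capacity_factor is zero) requests dropless
+ # routing, so the bins are sized to the busiest expert instead.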
+ sl, bs, _ = x.size() + expert_capacity = self.expert_capacity(sl * bs) + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = self.permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + self.top_k, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignement. After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + # + # Compute the mapping of local tokens to experts. + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so every device gets the counts. + repeated_tokens_per_expert = ops.repeat( + tokens_per_expert, + (mpu.hidden_sharding_degree(self.args),), + ) + + # Pass token count information to the device on which the + # target expert resides. + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.args.expert_parallel_group, + async_op=True, + ) + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + # + # This view updates the shape of the tensor from [sl, bs, hs] to + # [sl * bs, hs] prior to the permutation. + x = x.view(-1, x.shape[-1]) + output = ops.gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + experts_per_rank = mpu.experts_per_rank(self.args) + + # Reshape to [world_size, num_experts_per_rank]. + world_size = mpu.get_expert_parallel_world_size(self.args) + repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) + + # TODO(tgale): It might be faster to do this on the GPU and + # then communicate the results back to the host. + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. 
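+ # The distributed all_to_all needs host-side integer split sizes, hence the
+ # earlier .cpu() transfers and the .tolist() conversions here.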
+ send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so devices that share experts + # get all of the tokens assigned to them. + # + # TODO(tgale): Fuse this into the prior, local permutation. + x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, + recv_counts, + send_counts, + self.args.expert_parallel_group, + async_op=True, + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + replicate_bins = ops.inclusive_cumsum( + parallel_tokens_per_expert.flatten(), + 0, + ) + replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * mpu.hidden_sharding_degree(self.args), + dtype=torch.int32, + device=indices.device, + ), + mpu.experts_per_rank(self.args), + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, + tokens_received, + ).flatten() + + # TODO(tgale): The sort_end_bit here can be reduced. + parallel_bin_ids, parallel_indices = ops.sort( + parallel_top_expert, + self.sort_end_bit, + ) + + # Calculate the bins boundaries from the token counts. + parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, + dtype=torch.int, + ) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + tokens, _ = x.size() + expert_capacity = self.expert_capacity(tokens) + if expert_capacity == 0: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + if self.args.mlp_impl == 'grouped': + # GroupedMLP requires counts on CPU. We can use the tensor already + # moved to CPU for the prior all_to_all, which avoids an extra + # device synchronization. + parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( + dim=0, + dtype=torch.int, + ) + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all( + parallel_x, + send_counts, + recv_counts, + self.args.expert_parallel_group, + ) + + # Reduce along the hidden sharding to get the final outputs. + # + # TODO(tgale): Fuse this into the following local permutation. 
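+ # With hidden sharding each replica produced a partial result for the same tokens,
+ # so we sum over the replica dimension to recover the full output.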
+ shape = ( + mpu.hidden_sharding_degree(self.args), + -1, + self.args.hidden_size, + ) + x = ops.sum(x.view(shape), dim=0) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + return x, tokens_per_expert.flatten() + + def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + in_shape = x.size() + + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) + if self.training and self.args.moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, scores)) + x = x.view(in_shape) + if self.bias is not None: + if self.args.return_bias: + return x, self.bias + return x + self.bias + return x + + +class MoE(torch.nn.Module): + + def __init__(self, args: Arguments): + super(MoE, self).__init__() + + # Token router. + self.router = router.LearnedRouter(args) + + # Expert computation helper. + self.experts = self._init_experts_mlp(args) + + self.shared_expert = None + if args.shared_expert: + # SharedExpert computation helper. + self.shared_expert = sharedexpert_registry.get(args) + + def _init_experts_mlp(self, args: Arguments): + return ParallelMLP(args) + + def forward(self, x: torch.Tensor): + # NOTE: If we're going to cast the activations to lower precision + # do it before we permute the tokens to save bandwidth. + x = common.cast_if_autocast_enabled(x) + + # Compute the expert scores and assignments. + scores, expert_weights, top_experts = self.router(x) + + # Compute the experts. + out = self.experts(x, scores, expert_weights, top_experts) + if self.shared_expert is not None: + shared_expert_out = self.shared_expert(x) + out = self.shared_expert.add_experts_sharedexpert( + shared_expert_out, + out, + ) + return out diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mpu.py new file mode 100644 index 0000000000000000000000000000000000000000..b23213902f4567d7fdb0158cbcf5406c2b2aa601 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mpu.py @@ -0,0 +1,93 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.distributed as dist + +from megablocks.layers.arguments import Arguments + + +class MoeParam(torch.Tensor): + + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool + + +def is_moe_param(tensor: torch.Tensor) -> bool: + return hasattr(tensor, 'expert_model_parallel') + + +def get_expert_parallel_world_size(args: Arguments) -> int: + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) + + +def get_expert_parallel_rank(args: Arguments) -> int: + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) + + +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, 'expert_model_parallel') + setattr(tensor, 'expert_model_parallel', is_parallel) + + +def param_is_expert_model_parallel(param: MoeParam) -> bool: + return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) + + +def copy_expert_model_parallel_attributes( + destination_tensor: torch.Tensor, + source_tensor: torch.Tensor, +): + if hasattr(source_tensor, 'expert_model_parallel'): + setattr( + destination_tensor, + 'expert_model_parallel', + getattr(source_tensor, 
'expert_model_parallel'), + ) + + +def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + for i in range(world_size): + dist.barrier(group) + if i == rank: + print(f'rank = {rank}', *x) + + +# Helpers for expert/tensor sharding. +def expert_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = min(world_size, args.moe_num_experts) + + if (args.moe_num_experts % esd) != 0: + raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) + return esd + + +def hidden_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = expert_sharding_degree(args) + hsd = world_size // esd + + if (args.ffn_hidden_size % hsd) != 0: + raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", + ) + return hsd + + +def experts_per_rank(args: Arguments) -> int: + return args.moe_num_experts // expert_sharding_degree(args) + + +def features_per_rank(args: Arguments) -> int: + return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/router.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9dcd9e433322482a80d5b95afccee5c12368f8 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/router.py @@ -0,0 +1,114 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch + +from megablocks.layers import common +from megablocks.layers.arguments import Arguments + +_ROUTER_LOGITS = [] + + +def _save_router_logits(logits: torch.Tensor, args: Arguments): + if args.moe_zloss_weight == 0: + return + global _ROUTER_LOGITS + _ROUTER_LOGITS.append(logits) + + +def clear_router_zloss(): + global _ROUTER_LOGITS + _ROUTER_LOGITS.clear() + + +def batched_router_zloss(args: Arguments): + global _ROUTER_LOGITS + + if args.moe_zloss_weight == 0: + import warnings + warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') + return 0 + + logits_per_router = _ROUTER_LOGITS + + if args.moe_zloss_in_fp32: + logits_per_router = [logits.float() for logits in logits_per_router] + + unscaled_zloss_per_router = torch.stack([ + torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router + ]) + + return args.moe_zloss_weight * unscaled_zloss_per_router + + +# NOTE: To enable end-to-end benchmarking without convergence we +# support a flag to force the router to assign tokens uniformly +# across the experts. We do this with a custom autograd operation +# so that PyTorch still executes the full set of router operation. +class _UniformExpertAssignment(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, num_experts: int): + out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) + out = torch.remainder(out, num_experts) + return out.view(x.shape) + + +_uniform_expert_assignment = _UniformExpertAssignment.apply + + +class LearnedRouter(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + + # Learned router parameters. + # + # NOTE: This weight matrix is not parallelized with expert model + # parallelism. 
Each device needs the entire router weight matrix + # so that it can route its batch of data correctly. + self.layer = torch.nn.Linear( + args.hidden_size, + args.moe_num_experts, + bias=False, + dtype=common.dtype(args), + device=args.device, + ) + args.init_method(self.layer.weight) + + def jitter(self, x: torch.Tensor): + low: float = 1.0 - self.args.moe_jitter_eps + high: float = 1.0 + self.args.moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return low + noise * (high - low) + + def _top_k(self, scores: torch.Tensor): + if self.args.moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, self.args.moe_top_k, dim=-1) + + def forward(self, x: torch.Tensor): + if self.training and self.args.moe_jitter_eps is not None: + x = x * self.jitter(x) + + logits = self.layer(x.view(-1, x.shape[-1])) + _save_router_logits(logits, self.args) + scores = logits.softmax(dim=-1) + expert_weights, expert_indices = self._top_k(scores) + if self.args.moe_normalize_expert_weights: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + + expert_indices = ( + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) if self.args.uniform_expert_assignment else expert_indices + ) + return scores, expert_weights, expert_indices diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..0f62db39a2dda5642f6baa9d9b949f7c31cf6d35 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -0,0 +1,30 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +_REGISTRY = { + 'mlp': mlp.SharedMLP, + 'glu': glu.SharedGLU, +} + + +def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: + """Returns an SharedMLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + SharedMLP instance. + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated SharedMLP constructed using the input args. 
+ """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + return _REGISTRY[args.mlp_type](args) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from megablocks.ops.binned_gather import binned_gather +from megablocks.ops.binned_scatter import binned_scatter +from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum +from megablocks.ops.gather import gather +from megablocks.ops.histogram import histogram +from megablocks.ops.padded_gather import padded_gather +from megablocks.ops.padded_scatter import padded_scatter +from megablocks.ops.repeat import repeat +from megablocks.ops.replicate import replicate +from megablocks.ops.round_up import round_up +from megablocks.ops.scatter import scatter +from megablocks.ops.sort import sort +from megablocks.ops.sum import sum +from megablocks.ops.topology import topology + +__all__ = [ + 'binned_gather', + 'binned_scatter', + 'exclusive_cumsum', + 'inclusive_cumsum', + 'gather', + 'histogram', + 'padded_gather', + 'padded_scatter', + 'repeat', + 'replicate', + 'round_up', + 'scatter', + 'sort', + 'sum', + 'topology', +] diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..47b95301dbb35b0c3154d1ff2878d20792ce5cb2 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -0,0 +1,60 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + +from megablocks import benchmark_util +from megablocks.layers.all_to_all import all_to_all + +_ALL_TO_ALL_BENCHMARK = ( + (8, 1024), + (16, 1024), + (32, 1024), + (64, 1024), + (128, 1024), + (256, 1024), + (512, 1024), + (1024, 1024), + (2 * 1024, 1024), + (4 * 1024, 1024), + (8 * 1024, 1024), + (16 * 1024, 1024), + (32 * 1024, 1024), + (64 * 1024, 1024), + (128 * 1024, 1024), + (256 * 1024, 1024), + (512 * 1024, 1024), + (1024 * 1024, 1024), +) + + +def benchmark_all_to_all(group, sl, hs): + world_size = dist.get_world_size(group) + assert (sl % world_size) == 0 + send_recv_sizes = [sl // world_size] * world_size + + x = torch.randn((sl, hs)).cuda().half() + + details = { + 'world_size': world_size, + 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. 
+ } + + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) + + time, std = benchmark_util.benchmark_function(benchmark) + + if dist.get_rank(group) == 0: + benchmark_util.log_benchmark('All-To-All', details, time, std) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _ALL_TO_ALL_BENCHMARK: + benchmark_all_to_all(group, *args) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..89cce1b627137d7f987baca62fd6d82c5c04659a --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py @@ -0,0 +1,37 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_gather kernel. +class BinnedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): + ctx.save_for_backward(indices, bins) + ctx.top_k = top_k + return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + indices, bins = ctx.saved_tensors + out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) + return out, None, None, None, None + + +binned_gather = BinnedGatherOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ce0d6f5a890a58b89e6a501216a9822341323f --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py @@ -0,0 +1,59 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_scatter kernel. +class BinnedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + assert len(x.size()) == 3 + ctx.bin_size = x.size(1) + ctx.top_k = top_k + + # TODO(tgale): Don't save 'x' for backwards if we don't need to + # calculate the gradient w.r.t. 'weights'. 
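+ # One way to address the TODO above would be to gate the save on
+ # ctx.needs_input_grad, mirroring the guard used by PaddedScatterOp in this
+ # build (illustrative sketch only, not the original code):
+ #
+ #   maybe_x = [x] if ctx.needs_input_grad[2] else []
+ #   ctx.save_for_backward(indices, weights, bins, *maybe_x)
+ #
+ # The backward pass would then read 'x' from saved_tensors[-1] only when the
+ # gradient w.r.t. 'weights' is actually required.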
+ ctx.save_for_backward(x, indices, weights, bins) + return kernels.binned_scatter(x, indices, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + x, indices, weights, bins = ctx.saved_tensors + out = kernels.binned_gather( + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[2]: + wgrad = kernels.binned_scatter_wgrad( + x, + grad, + indices, + bins, + ctx.top_k, + ) + return out, None, wgrad, None, None + + +binned_scatter = BinnedScatterOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/cumsum.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7c5137c900d66a959a76fa681436362a4df906 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/cumsum.py @@ -0,0 +1,52 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + # import megablocks_ops as ops # type: ignore + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrappers for cumsum kernels. +# NOTE: Does not support gradients. +class ExclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int): + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.exclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.exclusive_cumsum(x, dim, out) + return out + + +exclusive_cumsum = ExclusiveCumsumOp.apply + + +class InclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.inclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.inclusive_cumsum(x, dim, out) + return out + + +inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/gather.py new file mode 100644 index 0000000000000000000000000000000000000000..41b09a1233e8a996ff1062579cd0810d095ad1e6 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/gather.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for gather kernel. 
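+# Typical use (illustrative sketch; 'expert_indices' and 'num_experts' stand
+# in for the router outputs and are not names from this module). The routing
+# metadata comes from the helpers in megablocks.ops:
+#
+#   top_expert = expert_indices.flatten().int()          # (tokens * top_k,)
+#   bin_ids, indices = ops.sort(top_expert)
+#   tokens_per_expert = ops.histogram(top_expert, num_experts)
+#   bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#   x_perm = gather(x, indices, bin_ids, bins, top_k)    # expert-ordered rows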
+class GatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins) + ctx.top_k = top_k + return kernels.gather(x, indices, bin_ids, None, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins = ctx.saved_tensors + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) + return out, None, None, None, None, None + + +gather = GatherOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram.py new file mode 100644 index 0000000000000000000000000000000000000000..77e06a29163aa335dc61e8c50932d5f83322ae1d --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram.py @@ -0,0 +1,27 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for histogram kernel. +# NOTE: Does not support gradients. +class HistogramOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, max_val: float): + return ops.histogram(x, max_val) + + +histogram = HistogramOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..9de8e65281c829ba63b5c60853dfa5bb0a333988 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -0,0 +1,78 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_HISTOGRAM_TESTS = ( + (16384, torch.int32, 2), + (16384, torch.int32, 4), + (16384, torch.int32, 8), + (16384, torch.int32, 16), + (16384, torch.int32, 32), + (16384, torch.int32, 64), + (16384, torch.int32, 128), + (16384, torch.int32, 256), +) + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. 
+ fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class HistogramBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testHistogram(self, n, dtype, max_val): + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testTorchHistogram(self, n, dtype, max_val): + x = torch.randint(0, 128, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..bfa7b7c2f697f9454c5381e31ee78040be5b229e --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -0,0 +1,403 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import stk +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + + +# Calling tensor.t() calls tensor.transpose(0, 1) which calls +# torch.as_strided(...). Circumvent this chain to avoid an overhead +# this adds. +def transpose_view(x): + return torch.as_strided( + x, + (x.shape[1], x.shape[0]), + (x.stride()[1], x.stride()[0]), + ) + + +_MATMUL_TESTS = ( + (64 * 1024, 512, 2048, 64), + (32 * 1024, 768, 3072, 64), + (8 * 1024, 1024, 4096, 64), + (4 * 2048, 4096, 4 * 4096, 4), +) + + +def log_benchmark(name, arguments, time, std, flops): + benchmark_util.log_benchmark(name, arguments, time, std) + print('flops = {:.2f}B'.format(flops / 1e9)) + print('throughput = {:.2f}T'.format(flops / 1e9 / time)) + print('=' * 60) + + +class MatmulBenchmark(parameterized.TestCase): + + def build_sparse_matrix(self, x, padded_bins, fhs, ne): + blocking = 128 + padded_tokens, _ = x.size() + assert padded_tokens % blocking == 0 + assert fhs % blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // blocking + blocks_per_row = fhs // blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. 
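+ # For example (sketch of the expected layout): with blocking = 128 and
+ # fhs = 512 there are 4 column blocks per row block, and a row block whose
+ # tokens were routed to expert 'e' touches column blocks e * 4 .. e * 4 + 3
+ # of the (padded_tokens, fhs * ne) intermediate matrix.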
+ column_indices = ops.topology( + padded_bins, + blocking, + block_rows, + blocks_per_row, + ) + data = torch.empty( + column_indices.numel(), + blocking, + blocking, + dtype=torch.float16, + device=x.device, + ) + shape = (padded_tokens, fhs * ne) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix(shape, data, row_indices, column_indices, offsets) + + def build_input_matrix(self, sl, hs, ne): + x = torch.randn((sl, hs)).cuda().half() + + # Assign tokens to experts uniformly. + top_expert = torch.arange(0, sl).cuda().int() % ne + + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) + return out, padded_bins + + def build_weight_matrix(self, ne, hs, fhs): + return torch.randn((hs, ne * fhs)).cuda().half() + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(x, w, topo) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd::SDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(topo, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradX::DSD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + topo = topo.t() + + def benchmark(): + return stk.ops.dsd(topo, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DSD::NN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): 
+ x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(out, w, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::SDD::NT', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + x = x.t() + + def benchmark(): + return stk.ops.dsd(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + + w = w.transpose(1, 2).contiguous() + w = w.transpose(1, 2) + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd:DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + w = w.transpose(1, 2).contiguous() + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradX:DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + out = out.transpose(1, 2) + + def benchmark(): + return torch.bmm(out, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradW:DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) 
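+ # Note on the FLOP counts passed to log_benchmark: each output element of a
+ # matmul costs one multiply and one add, so the first FFN layer performs
+ # x.numel() * fhs * 2 FLOPs and the second performs x.numel() * hs * 2,
+ # with x.nnz standing in for x.numel() when the activations are block-sparse.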
+ + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + w = torch.transpose(w, 1, 2) + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + x = torch.transpose(x, 1, 2) + + def benchmark(): + return torch.bmm(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..f272a7768dc6468e81bd2cd25294ca16a6826c08 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_gather.py @@ -0,0 +1,55 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_gather kernel. +class PaddedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins, padded_bins) + ctx.top_k = top_k + return kernels.padded_gather( + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins, padded_bins = ctx.saved_tensors + out = kernels.padded_scatter( + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) + return out, None, None, None, None, None + + +padded_gather = PaddedGatherOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff81dd9456a04258346266e9e85871bda56c65b --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter.py @@ -0,0 +1,98 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_scatter kernel. 
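+# padded_scatter is the inverse of padded_gather: it un-permutes the padded,
+# expert-ordered rows back to token order and, when 'weights' is given,
+# combines the top_k copies of each token as a weighted sum. Illustrative
+# round trip (names are placeholders for the surrounding dMoE code):
+#
+#   y = padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+#   y = expert_mlp(y)                                     # per-expert compute
+#   out = padded_scatter(y, indices, bin_ids, expert_weights, bins,
+#                        padded_bins, top_k)              # (tokens, hidden)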
+class PaddedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward( + indices, + bin_ids, + weights, + bins, + padded_bins, + *maybe_x, + ) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.padded_gather( + grad, + indices, + bin_ids, + weights, + bins, + padded_bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.padded_scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + padded_bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None, None + + +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): + return PaddedScatterOp.apply( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..81dde4e4b4c0f550c6027390e8de285fa4d842f7 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -0,0 +1,66 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PADDED_SCATTER_BENCHMARK = ( + # dMoE-Medium, 8-way EMP. + (1024 * 16, 1024, 8, 4), + # dMoE-Medium, post-all-to-all. + (1024 * 16 * 4, 1024, 8, 1), +) + + +class PaddedScatterTest(parameterized.TestCase): + + @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) + def testPaddedScatter(self, sl, hs, ne, top_k): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + # Sample weights for the scatter reduce. + weights = torch.rand((sl * top_k,)).cuda().half() + + # Gather the data to prepare for backwards. 
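+ # padded_scatter expects its input already in padded, expert-sorted order
+ # (padded_bins[-1] rows in total), so we run padded_gather once here and
+ # time only the scatter direction below.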
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + def benchmark(): + return ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + time, std = benchmark_util.benchmark_function(benchmark) + benchmark_util.log_benchmark( + 'Padded Scatter', + { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + 'top_k': top_k, + }, + time, + std, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..837f07e2858d5c65c14e3f262617e7d2985ce901 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -0,0 +1,149 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PERMUTE_TESTS = ( + (16384, 768, 2), + (16384, 768, 4), + (16384, 768, 8), + (16384, 768, 16), + (16384, 768, 32), + (16384, 768, 64), + (16384, 768, 128), + (16384 * 8, 768, 2), + (16384 * 8, 768, 4), + (16384 * 8, 768, 8), + (16384 * 8, 768, 16), + (16384 * 8, 768, 32), + (16384 * 8, 768, 64), + (16384 * 8, 768, 128), +) + + +class PermuteBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedGather(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.binned_gather(x, indices, bins, ec) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedScatter(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.binned_gather(x, indices, bins, ec) + + def benchmark(): + return ops.binned_scatter(x, indices, bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedGather(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. 
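+ # The bookkeeping below follows the usual routing recipe. As a small worked
+ # example (illustrative numbers): tokens_per_expert of [3, 5] rounded up to
+ # multiples of 128 gives [128, 128], so padded_bins = [128, 256] and
+ # padded_gather produces 256 output rows, while bins = [3, 8] still indexes
+ # the unpadded token counts.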
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedScatter(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testCopy(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + # ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + y = x.clone() + + def benchmark(): + return y.copy_(x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/repeat.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/repeat.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def repeat(x: torch.Tensor, tiling: torch.Size): + if all((t == 1 for t in tiling)): + return x + return x.repeat(*tiling) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..e570e22093339d1b71dc830edf656897c67f40e7 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/replicate.py @@ -0,0 +1,36 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. 
+try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for replicate kernel. +class ReplicateOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): + ctx.save_for_backward(bins) + out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) + ops.replicate_forward(x, bins, out) + return out + + @staticmethod + def backward(ctx: Any, grad: torch.Tensor): + bins, = ctx.saved_tensors + out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) + ops.replicate_backward(grad, bins, out) + return out, None, None + + +replicate = ReplicateOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/round_up.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/round_up.py new file mode 100644 index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/round_up.py @@ -0,0 +1,14 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def round_up(x: torch.Tensor, value: int): + assert isinstance(value, int) + assert x.dtype == torch.int32 + + # TODO(tgale): If this becomes and issue + # do this in a custom kernel. We only expect + # to use this on arrays of less than 1k elements. + return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..a5aaafc402a651a47595c25dbf71e382903e5022 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/scatter.py @@ -0,0 +1,72 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for scatter kernel. 
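+# scatter is the unpadded counterpart of padded_scatter: it restores token
+# order and applies the top_k weighted combine; the backend simply reuses the
+# padded kernels with padded_bins == bins. Sketch of the usual pairing
+# (placeholder names):
+#
+#   y = gather(x, indices, bin_ids, bins, top_k)          # expert-ordered
+#   out = scatter(y, indices, bin_ids, expert_weights, bins, top_k)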
+class ScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins = saved_tensors[:4] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.gather( + grad, + indices, + bin_ids, + weights, + bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None + + +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +) -> Optional[torch.Tensor]: + return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort.py new file mode 100644 index 0000000000000000000000000000000000000000..a64a4553323d8de71a1d5f1821bac50249c68197 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional, Tuple + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + +_BITS_FOR_DTYPE = { + torch.int16: 16, + torch.int32: 32, + torch.int64: 64, +} + + +# Autograd wrapper for sort kernel. +# NOTE: Does not support gradients. 
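+# Example (illustrative sketch): sorting router assignments so that tokens
+# bound for the same expert become contiguous. Passing end_bit =
+# ceil(log2(num_experts)) lets the radix sort skip the unused high bits;
+# with end_bit=None the full width of the dtype is used.
+#
+#   top_expert = torch.randint(0, 8, (4096,), device='cuda').int()
+#   bin_ids, indices = sort(top_expert, 3)   # sorted ids, original positions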
+class SortOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + if end_bit is None: + end_bit = _BITS_FOR_DTYPE[x.dtype] + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + ops.sort(x, end_bit, x_out, iota_out) + return (x_out, iota_out) + + +sort = SortOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -0,0 +1,85 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_SORT_TESTS = ( + (16384, torch.int32, None), + (16384, torch.int32, 2), + (16384, torch.int32, 128), +) + +_BASELINE_SORT_TESTS = ((16384,),) + + +def numpy_dtype(dtype): + types = { + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + } + return types[dtype] + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. + fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class SortBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_SORT_TESTS) + def testSort(self, n, dtype, max_val): + if max_val is None: + max_val = np.iinfo(numpy_dtype(dtype)).max + end_bit = int(np.ceil(np.log2(max_val))) + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_BASELINE_SORT_TESTS) + def testTorchSort(self, n): + x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) + arguments = { + 'n': n, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sum.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sum.py new file mode 100644 index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sum.py @@ -0,0 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import torch + + +def sum(x: torch.Tensor, dim: int = 0): + if x.shape[dim] == 1: + return x.squeeze(dim=dim) + return x.sum(dim=dim) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/topology.py new file mode 100644 index 
0000000000000000000000000000000000000000..b4a693a82a36c052e51afb4a02077e7849a4abf1 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/topology.py @@ -0,0 +1,45 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for topology kernel. +# NOTE: Does not support gradients. +class TopologyOp(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, + ): + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + ops.indices( + padded_bins, + block_size, + output_block_rows, + output_block_columns, + out, + ) + return out + + +topology = TopologyOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d8a2495c4baab89e5f90d84fa92f406a29ea10 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py @@ -0,0 +1,195 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from ._ops import ops + +from megablocks.layers.arguments import Arguments +from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE +from megablocks.layers.glu import SparseGLU +from megablocks.layers.mlp import MLP, SparseMLP +from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +# This section contains the direct kernel exports (not inlcuded in the original code) +def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute exclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.exclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute inclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.inclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: + """ + Compute histogram of input tensor values. + + Args: + x: Input tensor + num_bins: Number of histogram bins + + Returns: + Histogram tensor with counts for each bin + """ + return ops.histogram(x, num_bins) + + +def indices( + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, +) -> torch.Tensor: + """ + Construct indices from padded bins for sparse operations. 
+ + Args: + padded_bins: Tensor containing bin boundaries + block_size: Size of each block + output_block_rows: Number of rows in output blocks + output_block_columns: Number of columns in output blocks + + Returns: + Tensor containing constructed indices + """ + return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) + + +def replicate_forward( + x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Forward pass of replicate operation - replicate values according to bin sizes. + + Args: + x: Input tensor with values to replicate + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_forward(x, bins, out) + + +def replicate_backward( + grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Backward pass of replicate operation - reduce gradients back to bins. + + Args: + grad: Gradient tensor to reduce + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_backward(grad, bins, out) + + +def sort( + x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor +) -> torch.Tensor: + """ + Radix sort with index tracking. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + x_out: Output tensor for sorted values + iota_out: Output tensor for sorted indices + + Returns: + The sorted values tensor + """ + return ops.sort(x, end_bit, x_out, iota_out) + + +# Convenience functions for common use cases +def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: + """ + Compute cumulative sum with automatic output allocation. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum (default: last dimension) + exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum + + Returns: + New tensor containing the cumulative sum + """ + out = torch.empty_like(x) + if exclusive: + return exclusive_cumsum(x, dim, out) + else: + return inclusive_cumsum(x, dim, out) + + +def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sort tensor and return both sorted values and indices. 
+ + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + + Returns: + Tuple of (sorted_values, sorted_indices) + """ + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + sort(x, end_bit, x_out, iota_out) + return x_out, iota_out + + +# Export public API +__all__ = [ + # Direct kernel exports + "exclusive_cumsum", + "inclusive_cumsum", + "histogram", + "indices", + "replicate_forward", + "replicate_backward", + "sort", + "cumsum", + "argsort", + # Original exports + "Arguments", + "ParallelDroplessMLP", + "dMoE", + "SparseGLU", + "MLP", + "SparseMLP", + "MoE", + "ParallelMLP", + "get_load_balancing_loss", +] diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..73be500db9e30d7b8ad95919d224ae58090f5a70 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:38fcc4266dd94ee3f307bebd1264a1faaf91c73dc680adb72cd268748130b10f +size 11835888 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd42432c910375fb1713b3bf29145e48b556498 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _megablocks_359242d +ops = torch.ops._megablocks_359242d + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_megablocks_359242d::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_version.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_version.py @@ -0,0 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +"""The MegaBlocks Version.""" + +__version__ = '0.11.0.dev0' diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/backend/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/backend/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..b584ceede926ca30abef2dec581cb3ff329e8e16 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py @@ -0,0 +1,543 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') + + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') + + +def assert_equal(a, b): + if a != b: + raise ValueError(f'Expected 
dimensions to be equal but got {a} and {b}.',) + + +# a: (tokens, hidden_size), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy( + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Our index into array 'a'. + index_a = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'b'. + index_b = offset_in_bin + if bin_idx > 0: + index_b += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: Because of the padding, the output size is dynamic. + # We load the final padded bin bound to get the output rows. + output_rows = padded_bins[-1].cpu().item() + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def gather(x, indices, bin_ids, weights, bins, top_k): + # Validate the input shapes. 
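+ # Expected shapes (see the comment block above _padded_copy): x is
+ # (tokens, hidden_size), indices and bin_ids are (tokens * top_k,) and bins
+ # is (num_experts,). With tokens=4 and top_k=2, for example, the index
+ # vectors have 8 entries and the gathered output has 8 rows.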
+ assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: There is no padding so the output rows equals the + # input rows multiplied by top_k. + output_rows = x.shape[0] * top_k + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + out, + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + + +def scatter(x, indices, bin_ids, weights, bins, top_k): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +# x: (tokens, top_k, hidden_size), real +# grad: (tokens, hidden_size), real. +# wgrad: (tokens, top_k), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy_wgrad( + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Our index into 'tokens * top_k'. + index_out = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'x'. + index_x = offset_in_bin + if bin_idx > 0: + index_x += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. 
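+ # Each program instance produces one scalar routing-weight gradient:
+ # wgrad[index_out] = dot(x[index_x, :], grad[index_out // TOP_K, :]),
+ # accumulated in fp32 over BLOCK_X-wide chunks of the hidden dimension.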
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + _padded_copy_wgrad[(indices.shape[0],)]( + x, + grad, + out, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + TOP_K=top_k, + ) + return out + + +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy( + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_b = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_a = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + # + # NOTE: We need to zero the output in both directions. 
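+ # Rows of the binned tensor beyond a bin's token count are never written
+ # when gathering, and tokens dropped for exceeding the expert capacity are
+ # never written when scattering, which is why the callers allocate the
+ # output with torch.zeros rather than torch.empty.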
+ iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + num_experts = bins.shape[0] + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + + _binned_copy[(num_experts, expert_capacity)]( + x, + out, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def binned_scatter(x, indices, weights, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( + out, + x, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=hidden_size, + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy_wgrad( + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_x = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_out = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. 
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def binned_scatter_wgrad(x, grad, indices, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) + _binned_copy_wgrad[(num_experts, expert_capacity)]( + x, + grad, + out, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS=hidden_size, + TOP_K=top_k, + ) + return out diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/bak.__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/bak.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5217959caf74527e3bf7f80db6f93be21c016963 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/bak.__init__.py @@ -0,0 +1,23 @@ +from megablocks_moe.megablocks import ( + MoE, + dMoE, + get_load_balancing_loss, + ParallelMLP, + ParallelDroplessMLP, + SparseMLP, + MLP, + SparseGLU, + Arguments, +) + +__all__ = [ + "MoE", + "dMoE", + "get_load_balancing_loss", + "ParallelMLP", + "ParallelDroplessMLP", + "SparseMLP", + "MLP", + "SparseGLU", + "Arguments", +] diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/benchmark_util.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/benchmark_util.py new file mode 100644 index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/benchmark_util.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + + +def log_benchmark(name, arguments, time, std): + print('=' * 60) + print(f'{name} Benchmark') + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) + print('=' * 60) + + +def benchmark_function(fn, iterations=100, warmup=10): + # Warmup iterations. 
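+ # These calls are not timed; they give CUDA initialization and any kernel
+ # autotuning a chance to complete so that the timed iterations below
+ # measure steady-state performance.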
+ for _ in range(warmup): + fn() + + times = [] + for i in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + fn() + end.record() + + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + return np.mean(times), np.std(times) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6d3f977f360fd0ad5800c3b5da9ce57be794b9b8 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import warnings + +_grouped_gemm_is_available: bool = False +try: + import grouped_gemm + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn('Grouped GEMM not available.') + + +def grouped_gemm_is_available(): + return _grouped_gemm_is_available + + +def assert_grouped_gemm_is_available(): + msg = ( + 'Grouped GEMM not available. Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg + + +backend = grouped_gemm.backend if grouped_gemm_is_available() else None +ops = grouped_gemm.ops if grouped_gemm_is_available() else None diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..849b023b1a0f765fb1e26addf4670dc6db785a52 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# from megablocks.layers.dmoe import dMoE +from megablocks.layers.moe import MoE + +__all__ = [ + 'MoE', + # 'dMoE', +] diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/activation_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..a31770ba179fed06abf2da10102ccaeed1d3ee4e --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/activation_fn.py @@ -0,0 +1,33 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Union + +import torch +from stk import Matrix + + +def act_fn( + x: Matrix, + function: Callable, + return_grad_fn: bool = False, + **kwargs, +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) + with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): + if return_grad_fn: + x.data.requires_grad = True + out = function(x.data, **kwargs) + y = Matrix( + x.size(), + out, + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + if return_grad_fn: + return y, out.backward + return y diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/all_to_all.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/all_to_all.py @@ -0,0 +1,54 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + + +class 
AllToAllOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) + + ctx.input_shape = x.shape + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + ctx.group = group + handle = dist.all_to_all_single( + out, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + return out, handle + + @staticmethod + def backward(ctx, grad, _): + if ctx.needs_input_grad[0]: + out = torch.empty( + ctx.input_shape, + device=grad.device, + dtype=grad.dtype, + ) + dist.all_to_all_single( + out, + grad, + output_split_sizes=ctx.input_split_sizes, + input_split_sizes=ctx.output_split_sizes, + group=ctx.group, + ) + return out, None, None, None, None + return None, None, None, None, None + + +def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): + return AllToAllOp.apply( + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..3962c771c90012535aab058443f89c541a2e9236 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py @@ -0,0 +1,100 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from functools import partial +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +import megablocks.grouped_gemm_util as grouped_gemm + +# Type annotation for in-place Tensor initialization function. +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] + +_ALLOWED_BITWIDTHS = (-1, 4, 8) + +DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') + + +@dataclasses.dataclass +class Arguments: + # Model arguments. + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + num_layers: int = 1 + bias: bool = True + return_bias: bool = True + activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN + + # MoE arguments. + moe_num_experts: int = 1 + moe_top_k: int = 1 + moe_capacity_factor: int = 1 + moe_normalize_expert_weights: Optional[Union[int, float]] = None + moe_loss_weight: float = 0.1 + moe_jitter_eps: Optional[float] = None + moe_lbl_in_fp32: bool = False + + # Parallelism arguments. + moe_expert_model_parallelism: bool = False + expert_parallel_group: Optional[dist.ProcessGroup] = None + pipeline_model_parallel_size: int = 1 + num_layers_per_virtual_pipeline_stage: Optional[int] = None + + # Compute arguments. + memory_optimized_mlp: bool = False + mlp_type: str = 'mlp' + mlp_impl: str = 'sparse' + + # Initialization arguments. + fp16: bool = True + bf16: bool = False + device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) + init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) + output_layer_init_method: InitFn = init_method + + # Benchmarking arguments. 
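+ # If enabled, the learned routing is replaced with a fixed round-robin
+ # token-to-expert assignment so benchmarks observe a perfectly balanced
+ # expert load rather than the model's actual routing behavior.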
+ uniform_expert_assignment: bool = False + + # shared expert arguments + shared_expert: bool = False # enable using shared expert + fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) + fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers + remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored + shared_expert_hidden_size: Optional[ + int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size + shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) + + # Router Z-loss arguments + moe_zloss_weight: float = 0 # 1e-3 is a reasonable value + moe_zloss_in_fp32: bool = False + + def __post_init__(self): + # Sparse MLP is not supported with triton >=3.2.0 + # TODO: Remove this once sparse is supported with triton >=3.2.0 + if self.__getattribute__('mlp_impl') == 'sparse': + try: + import triton + if triton.__version__ >= '3.2.0': + raise ValueError( + 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.', + ) + except ImportError: + raise ImportError('Triton is required for sparse MLP implementation') + + if self.__getattribute__('mlp_impl') == 'grouped': + grouped_gemm.assert_grouped_gemm_is_available() + + if self.shared_expert_hidden_size is None: + self.shared_expert_hidden_size = self.ffn_hidden_size + + +def from_megatron(megatron_args: Any): + args = Arguments() + for field in dataclasses.fields(args): + if hasattr(megatron_args, field.name): + setattr(args, field.name, getattr(megatron_args, field.name)) + return args diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/common.py new file mode 100644 index 0000000000000000000000000000000000000000..ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/common.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from megablocks.layers.arguments import Arguments + + +def dtype(args: Arguments): + if args.fp16: + return torch.float16 + elif args.bf16: + return torch.bfloat16 + return None + + +def cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == 'cuda': + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == 'cpu': + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d765bd04387a29bddd2789fae04821635b555e82 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -0,0 +1,42 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +MlpType = Union[mlp.SparseMLP, glu.SparseGLU] + +_REGISTRY = { + 'mlp': { + 'grouped': mlp.GroupedMLP, + 'sparse': mlp.SparseMLP, + }, + 'glu': { + 'grouped': glu.GroupedGLU, + 'sparse': glu.SparseGLU, + }, +} + + +def get(args: Arguments) -> MlpType: + 
"""Returns an MLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + MLP instance. This only contains MLPs for use in dMoEs + (ie. only for the dropless versions of MoEs). + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated MLP constructed using the input args. + """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + if args.mlp_impl not in _REGISTRY[args.mlp_type]: + raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) + + return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..205727ff4d63f9e8dc9648acaac99a97f3394d6f --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py @@ -0,0 +1,327 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import stk.ops +import torch +from stk import Matrix + +import megablocks.ops as ops +# from megablocks.ops import ops +from megablocks.layers import common, dmlp_registry, moe, mpu +from megablocks.layers.arguments import Arguments + + +def promote_scalar(x): + return x.view(1) if not len(x.size()) else x + + +class ParallelDroplessMLP(moe.ParallelMLP): + + def __init__(self, args: Arguments): + super(ParallelDroplessMLP, self).__init__(args) + self.hidden_size = args.hidden_size + self.ffn_hidden_size = mpu.features_per_rank(args) + self.blocking = 128 + self.mlp = dmlp_registry.get(args) + + # Calculate the number of bits needed to represent the column indices + # in the intermediate sparse matrix. + max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) + self.transpose_sort_end_bit = max( + int(np.ceil(np.log2(max_column_index))), + 1, + ) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + # + # NOTE: Our sort operation uses the same width indices as the input values. + # To avoid overflow when we have large activation matrices we cast to + # 32-bit before sorting. + _, gather_indices = ops.sort( + column_indices.int(), + self.transpose_sort_end_bit, + ) + + # There are a constant number of blocks in every row of the sparse matrix. + # A blocks offset is: + # + # row_index * blocks_per_row + column_index % blocks_per_row + # + # Once we have the block offsets ordered for transposition we can divide + # by blocks_per_row to get the transposed column indices. + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + if nnz_per_column.dim() == 0: + # This addresses an edge case when ffn_hidden_size is equal to self.blocking. 
+ nnz_per_column = nnz_per_column.unsqueeze(0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + if self.ffn_hidden_size % self.blocking != 0: + raise ValueError( + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + + f'the block size {self.blocking}. Please update your configuration.', + ) + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // self.blocking + blocks_per_row = self.ffn_hidden_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology( + padded_bins, + self.blocking, + block_rows, + blocks_per_row, + ) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=common.dtype(self.args), + device='meta', + ) + shape = ( + padded_tokens, + self.ffn_hidden_size * mpu.experts_per_rank(self.args), + ) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, + row_indices, + column_indices, + offsets, + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_experts = top_experts.int() + bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Caculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = promote_scalar(bins) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def sparse_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather( + x, + indices, + bin_ids, + bins, + padded_bins, + self.top_k, + ) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. 
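+ # padded_scatter drops the padding rows, scales each copy by its expert
+ # weight and, for top_k > 1, sums the copies of every token back into a
+ # single row, giving an output of shape [sl * bs, hs].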
+ x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.top_k, + ) + return x, tokens_per_expert + + # For use in the base-class parallel_forward_once. + def sparse_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + return ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + top_k, + ) + + def grouped_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + out = self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + -1, # unused + self.args.moe_top_k, + ) + return out, tokens_per_expert + + def grouped_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Perform the expert computation. + x = self.mlp(x, tokens_per_expert) + + # Un-route the data for the MoE output. 
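+ # 'scatter' is the unpadded variant of padded_scatter: the grouped GEMM
+ # path operates directly on per-expert token counts, so no block-size
+ # padding was added on the gather side and none needs to be removed here.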
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + def forward_once(self, x, expert_weights, top_experts): + if self.args.mlp_impl == 'sparse': + return self.sparse_forward_once(x, expert_weights, top_experts) + else: + return self.grouped_forward_once(x, expert_weights, top_experts) + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ): + if self.args.mlp_impl == 'sparse': + return self.sparse_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + else: + return self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + + +class dMoE(moe.MoE): + + def _init_experts_mlp(self, args: Arguments): + return ParallelDroplessMLP(args) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/gelu.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..40b601d4a8eb59e80e1090f31f4172b8f7fb7549 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/gelu.py @@ -0,0 +1,43 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk +import torch +import torch.nn.functional as F + + +@torch.jit.script +def _gelu_backward_inplace(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) + return g.mul_(ff) + + +def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): + # NOTE: The two sparse matrices must have the same topology. + if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): + return stk.Matrix( + x.size(), + _gelu_backward_inplace(grad.data, x.data), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + return _gelu_backward_inplace(grad, x) + + +def gelu(x: stk.Matrix): + assert isinstance(x, stk.Matrix) + return stk.Matrix( + x.size(), + F.gelu(x.data, approximate='tanh'), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe0c915c307e7f7cade3ea3ff679399635fcd81 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py @@ -0,0 +1,223 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk.ops +import torch + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import Arguments +from megablocks.layers.mlp import ( + SharedMLP, + SparseMLP, + create_dmoe_expert_weights, + resolve_dtensor, +) + + +class SparseGLU(SparseMLP): + + def __init__(self, args: Arguments): + super().__init__(args) + self.v1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + with torch.no_grad(): + self.v1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + + mpu.set_expert_model_parallel_attributes( + 
self.v1, + self._should_set_parallelism_attribute, + ) + + def forward(self, x, topo): + if self.args.memory_optimized_mlp: + raise NotImplementedError( + 'Memory optimized implementation not yet supported with GLU with sparse kernels.', + ) + + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Compute the GLU. + x1 = stk.ops.sdd(x, w1.t(), topo) + x2 = stk.ops.sdd(x, v1.t(), topo) + + activation_fn_out = act_fn(x1, self.args.activation_fn) + x1 = stk.ops.mul(activation_fn_out, x2) + + return stk.ops.dsd(x1, w2) + + +class MemoryOptimizedGroupedGLU(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + v1 = v1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) + + # GeLU. + activation_fn_out = activation_fn(sdd_out) * v1_out + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, v1, w2 = saved_tensors[:3] + batch_sizes = saved_tensors[3] + x = saved_tensors[4] + sdd_out, v1_out = saved_tensors[5:7] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + v1_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * v1_out + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + dv1_out = v1_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dv1. 
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + dx = ddsd_out + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) + dx += gg.backend.gmm(dv1_out, v1, batch_sizes) + return dx, dw1, dv1, dw2, None, None + + +memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply + + +class GroupedGLU(SparseGLU): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, v1, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.v1), + self.scale_grad(self.w2), + ) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Re-shape the weights for the grouped GEMMs. + ne = mpu.experts_per_rank(self.args) + w1 = w1.view(ne, -1, self.args.hidden_size) + v1 = v1.view(ne, -1, self.args.hidden_size) + w2 = w2.view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_glu( + x, + w1, + v1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) + x1 = self.args.activation_fn(x1) * x2 + return gg.ops.gmm(x1, w2, batch_sizes) + + +class SharedGLU(SharedMLP): + """GPU for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class + """ + + def __init__(self, args: Arguments): + super().__init__(args) + self.gate_proj = args.fc_cls( + args.hidden_size, + self.args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4acbd94f212ce906ace9de78dd3ffc6afa03f97e --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py @@ -0,0 +1,102 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import gc + +import torch +import torch.distributed as dist + +from megablocks.layers import arguments, dmoe + +_TESTS = ((8, 2048, 4096, 4096, 32, 4),) + + +def get_tensors(): + ptrs = set() + out = [] + for obj in gc.get_objects(): + if torch.is_tensor(obj): + if not obj.is_contiguous() or obj.data_ptr() in ptrs: + continue + out.append(obj) + ptrs.add(obj.data_ptr()) + return out + + +def test_memory( + group, + batch_size, + sequence_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, +): + args = arguments.Arguments( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=num_experts, + moe_top_k=top_k, + moe_expert_model_parallelism=True, + expert_parallel_group=group, + fp16=False, + bf16=True, + device=torch.cuda.current_device(), + ) + layer = dmoe.dMoE(args).cuda() + + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16).requires_grad_(True) + torch.cuda.empty_cache() + + # Run forward + backward. + # with torch.autograd.detect_anomaly(): + out, _ = layer(x) + out.mean().backward() + + # Report peak memory. 
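+ # max_memory_allocated() reports the peak since program start (or since
+ # the last torch.cuda.reset_peak_memory_stats() call), so the figure below
+ # covers the forward and backward pass that just ran.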
+ mem = torch.cuda.max_memory_allocated() + print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) + print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) + + # Calculate weight and gradient memory usage. + weight_memory = 2 * ( + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() + ) + + def grad_numel(x): + if x.grad is not None: + return x.grad.numel() + return 0 + + grad_memory = 2 * ( + grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + ) + weight_memory += grad_memory + + print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) + print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) + + # Manually calculate GPU memory usage from the garbage + # collector. + gc.collect() + total = 0 + tensors = get_tensors() + tensors = sorted(tensors, key=lambda x: -x.numel()) + for i, t in enumerate(tensors): + total += t.numel() + print(f'{i}: {t.shape}, {t.numel() * 2}') + del tensors + + print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _TESTS: + test_memory(group, *args) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6f4d82441e3a9c185db5bfdf686d53790dde26 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py @@ -0,0 +1,574 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import stk +import stk.backend.triton_kernels +import stk.ops +import torch +from packaging import version + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, gelu, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + + +class ScaleGradient(torch.autograd.Function): + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx: Any, x: torch.Tensor, scale: float): + ctx.scale = scale + return x + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: torch.Tensor, grad: torch.Tensor): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +def resolve_dtensor(weight: torch.Tensor): + if version.parse(torch.__version__) >= version.parse('2.0.0'): + from torch.distributed._tensor import DTensor + if isinstance(weight, DTensor): + return weight.to_local() + return weight + + +def create_moe_expert_weights( + args: Arguments, + num_experts: int, + ffn_hidden_size: int, + hidden_size: int, + init_method: InitFn, +): + # Create the entire weight matrix such that the sampled weights will + # not vary between data parallelism and expert model parallelism for + # the same random seed. + master_weights = torch.empty( + num_experts, + ffn_hidden_size, + hidden_size, + device=args.device, + dtype=common.dtype(args), + ) + init_method(master_weights) + + if not args.moe_expert_model_parallelism: + return master_weights + + # Calculate the amount of sharding in each dimension. 
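+ #
+ # For example, with 8 experts on 4 expert-parallel ranks the expert
+ # dimension is sharded 4 ways (2 experts per rank); if the number of ranks
+ # exceeds the number of experts, the extra parallelism shards the
+ # ffn_hidden_size dimension instead (hidden_sharding_degree > 1).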
+ expert_sharding_degree = mpu.expert_sharding_degree(args) + hidden_sharding_degree = mpu.hidden_sharding_degree(args) + + # Calculate the experts per rank. + # + # NOTE: We assign ranks to be expert parallel before going + # tensor parallel. + rank = mpu.get_expert_parallel_rank(args) + expert_rank = rank % expert_sharding_degree + num_experts_per_rank = num_experts // expert_sharding_degree + start_expert = expert_rank * num_experts_per_rank + end_expert = (expert_rank + 1) * num_experts_per_rank + + # Calculate the rows per rank. + row_rank = rank // expert_sharding_degree + num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree + start_row = row_rank * num_rows_per_rank + end_row = (row_rank + 1) * num_rows_per_rank + + # Slice the weight matrix to get the chunk for this rank. + with torch.no_grad(): + weights = master_weights[start_expert:end_expert, start_row:end_row] + return weights + + +class MLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) + experts_per_rank = mpu.experts_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + args.hidden_size, + mpu.features_per_rank(args), + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + mpu.features_per_rank(args), + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + mpu.set_expert_model_parallel_attributes( + self.w1, + args.moe_expert_model_parallelism, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + args.moe_expert_model_parallelism, + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. 
+ with torch.no_grad(): + w1 = create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) + self.w1.copy_(w1.transpose(1, 2).contiguous()) + self.w2.copy_( + create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + x = torch.bmm(x, w1) + x = self.args.activation_fn(x) + return torch.bmm(x, w2) + + +def create_dmoe_expert_weights( + args: Arguments, + num_experts: int, + rows: int, + columns: int, + init_method: InitFn, +): + weights = create_moe_expert_weights( + args, + num_experts, + rows, + columns, + init_method, + ) + return weights.view([-1, columns]) + + +class MemoryOptimizedMLP(torch.autograd.Function): + """Sparse MLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, topo, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + topo_tensors = ( + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) + + # Layer 0: x @ w1.t(). + sdd_out = stk.ops.sdd(x, w1.t(), topo) + + # GeLU. + activation_fn_out = act_fn(sdd_out, activation_fn) + + # Layer 1: x @ w2. + dsd_out = stk.ops.dsd(activation_fn_out, w2) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.shape = topo.shape + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.data.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + topo_tensors = saved_tensors[2:8] + x = saved_tensors[8] + sdd_out_data = saved_tensors[9] + + # rematerialize activation function output + activation_fn = ctx.activation_fn + sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) + activation_fn_out, activation_grad_fn = act_fn( + sdd_out, + activation_fn, + return_grad_fn=True, + ) + + # Compute dw2 with recomputed activation_fn output. + dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. 
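+ # Writing the SDD result straight into the recomputed activation output
+ # avoids allocating a second block-sparse matrix with the same topology,
+ # at the cost of clobbering activation_fn_out (dw2 has already been
+ # computed from it above).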
+ dactivation_fn_out = activation_fn_out + stk.backend.triton_kernels.sdd( + ddsd_out, + w2.t(), + dactivation_fn_out.shape, + dactivation_fn_out.data, + dactivation_fn_out.offsets, + dactivation_fn_out.row_indices, + dactivation_fn_out.column_indices, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out.data) + dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) + + # Compute dw1. + dw1 = stk.ops.dsd(dsdd_out.t(), x) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + stk.backend.triton_kernels.dsd( + dsdd_out.shape, + dsdd_out.data, + dsdd_out.offsets, + dsdd_out.row_indices, + dsdd_out.column_indices, + dsdd_out.offsets_t, + dsdd_out.column_indices_t, + dsdd_out.block_offsets_t, + False, + w1, + ddsd_out, + ) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_mlp = MemoryOptimizedMLP.apply + + +class SparseMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. + with torch.no_grad(): + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism + mpu.set_expert_model_parallel_attributes( + self.w1, + self._should_set_parallelism_attribute, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + self._should_set_parallelism_attribute, + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x, topo): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + if self.args.memory_optimized_mlp: + return memory_optimized_mlp( + x, + w1, + w2, + topo, + self.args.activation_fn, + ) + + # Compute the MLP. 
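+ # stk.ops.sdd multiplies two dense operands but materializes only the
+ # blocks present in 'topo', producing a block-sparse activation;
+ # stk.ops.dsd then multiplies that block-sparse matrix by the dense w2 to
+ # produce a dense output.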
+ x = stk.ops.sdd(x, w1.t(), topo) + activation_fn_out = act_fn(x, self.args.activation_fn) + return stk.ops.dsd(activation_fn_out, w2) + + +class MemoryOptimizedGroupedMLP(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + + # activation_fn + activation_fn_out = activation_fn(sdd_out) + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: Any, ddsd_out: torch.Tensor): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + batch_sizes = saved_tensors[2] + x = saved_tensors[3] + sdd_out = saved_tensors[4] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply + + +class GroupedMLP(SparseMLP): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) + + # Re-shape the weights for the grouped GEMMs. 
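+ # The flat (num_rows_per_rank, hidden_size) parameters are viewed as
+ # (experts_per_rank, ffn_per_expert, hidden_size) so that the grouped GEMM
+ # runs one matmul per expert, with batch_sizes giving the number of rows
+ # of x that belong to each expert.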
+ ne = mpu.experts_per_rank(self.args) + w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) + w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_mlp( + x, + w1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x = self.args.activation_fn(x) + return gg.ops.gmm(x, w2, batch_sizes) + + +class SharedMLP(torch.nn.Module): + """MLP for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class + """ + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self.fc_kwargs: dict[str, Any] = { + 'bias': args.bias, + 'device': args.device, + } + self.fc_kwargs.update(args.fc_kwargs) + + self.up_proj = args.fc_cls( + args.hidden_size, + args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + self.act = args.activation_fn + self.down_proj = args.fc_cls( + args.shared_expert_hidden_size, + args.hidden_size, + **self.fc_kwargs, + ) + self.down_proj._is_residual = True # a flag for llm-foundry init + + def add_experts_sharedexpert( + self, + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + ) -> torch.Tensor: + # Helper function to add expert output to shared expert output + # with optional weighted sum. + if self.args.shared_expert_weighted_sum: + # enable using weighted sum for shared expert output + # wieghted by number of experts used + t_experts = self.args.moe_top_k + 1 + sh_mlp_out = shared_expert_out / t_experts + return sh_mlp_out.add( + expert_out, + alpha=(self.args.moe_top_k / t_experts), + ) + + return shared_expert_out + expert_out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py @@ -0,0 +1,475 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist + +import megablocks.ops as ops +from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +from megablocks.layers.all_to_all import all_to_all +from megablocks.layers.arguments import Arguments + +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss(args: Arguments): + if args.moe_loss_weight == 0: + return 0.0 + + # tokens_per_expert[i].shape = (num_experts) + # expert_scores[i].shape = (tokens, num_experts) + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} token_per_experts ' 
+ f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} expert_scores ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) + + tokens = expert_scores[0].shape[0] + assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = (args.moe_num_experts * args.moe_loss_weight) + scale_denominator = (args.num_layers * tokens * args.moe_top_k) + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +# NOTE: This class defines MoE expert computation, including expert model parallel +# communication. When using FSDP on top of MegaBlocks this is the module that should +# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model +# parallel all2all. +class ParallelMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super(ParallelMLP, self).__init__() + self.args = args + + # Calculate the number of experts in total and the number of experts + # owned by this rank. + # world_size = mpu.get_expert_parallel_world_size(args) + self.num_experts = args.moe_num_experts + self.top_k = self.args.moe_top_k + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + # Expert MLP. + self.mlp = mlp.MLP(args) + + self.bias: Optional[torch.Tensor] + if self.args.bias: + # Note that the output bias is not parallelized with expert + # model parallelism. + self.bias = torch.nn.Parameter( + torch.empty( + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + # Select the forward function for the operating mode. 
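For intuition, the batched loss above reduces to a scaled dot product between per-expert token counts and the mean router scores. A minimal single-layer sketch in plain PyTorch with made-up sizes (the batched version additionally folds `num_layers` into the scale and concatenates the per-layer contributions):

```python
import torch

# Hypothetical sizes for illustration only.
tokens, num_experts, top_k, moe_loss_weight = 8, 4, 2, 0.1

expert_scores = torch.rand(tokens, num_experts).softmax(dim=-1)    # router probabilities
top_experts = expert_scores.topk(top_k, dim=-1).indices.flatten()  # routed assignments
tokens_per_expert = torch.bincount(top_experts, minlength=num_experts).float()

# Same scale as above: loss_weight * num_experts / (tokens * top_k).
scale = moe_loss_weight * num_experts / (tokens * top_k)
loss = scale * torch.dot(tokens_per_expert, expert_scores.mean(dim=0))
```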
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) + + def expert_capacity(self, tokens: int) -> int: + world_size = mpu.get_expert_parallel_world_size(self.args) + tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) + return int(self.args.moe_capacity_factor * tokens_per_expert) + + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): + """Calculate the load balancing loss contribution.""" + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == self.num_experts + assert len(tokens_per_expert.size()) == 1 + num_experts, = tokens_per_expert.size() + assert num_experts == self.num_experts + scale = self.num_experts / (tokens * self.top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # + # TODO(tgale): Is it worth doing this conversion to 32-bit + # prior? Could we place the `torch.max` operation to return + # 32-bit expert indices? + top_expert = top_expert.int() + output = ops.sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + # + # TODO(tgale): Does the sorted data produce a more favorable + # data distribution for histogram? Or is the op parallelism + # worth more? + tokens_per_expert = ops.histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None + bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + + return indices, bin_ids, bins, tokens_per_expert + + def permute_and_compute( + self, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, + ): + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output + + # Perform the expert computation. Note that we don't + # use biases for these linear operations. + x = self.mlp(x) + + # Un-route the data for the MoE output. + return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. 
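The routing metadata produced by `indices_and_bins` can be emulated with standard PyTorch ops for small CPU inputs; the real code uses the fused `ops.sort`, `ops.histogram`, and `ops.inclusive_cumsum` kernels, but the semantics are roughly:

```python
import torch

num_experts = 4
top_expert = torch.tensor([2, 0, 3, 2, 1, 0, 2, 3], dtype=torch.int32)

bin_ids, indices = torch.sort(top_expert)                         # expert id per sorted token
tokens_per_expert = torch.bincount(top_expert, minlength=num_experts)
bins = torch.cumsum(tokens_per_expert, dim=0)                     # inclusive bin boundaries

# bin_ids -> [0, 0, 1, 2, 2, 2, 3, 3]
# bins    -> [2, 3, 6, 8]: sorted tokens [0, 2) belong to expert 0, [2, 3) to expert 1, ...
```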
+ sl, bs, _ = x.size() + expert_capacity = self.expert_capacity(sl * bs) + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = self.permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + self.top_k, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignement. After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + # + # Compute the mapping of local tokens to experts. + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so every device gets the counts. + repeated_tokens_per_expert = ops.repeat( + tokens_per_expert, + (mpu.hidden_sharding_degree(self.args),), + ) + + # Pass token count information to the device on which the + # target expert resides. + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.args.expert_parallel_group, + async_op=True, + ) + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + # + # This view updates the shape of the tensor from [sl, bs, hs] to + # [sl * bs, hs] prior to the permutation. + x = x.view(-1, x.shape[-1]) + output = ops.gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + experts_per_rank = mpu.experts_per_rank(self.args) + + # Reshape to [world_size, num_experts_per_rank]. + world_size = mpu.get_expert_parallel_world_size(self.args) + repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) + + # TODO(tgale): It might be faster to do this on the GPU and + # then communicate the results back to the host. + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. 
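The send/recv bookkeeping above boils down to summing the `[world_size, experts_per_rank]` view over its expert axis; a small single-process sketch with hypothetical counts:

```python
import torch

world_size, experts_per_rank = 2, 2

# Local tokens routed to each of the four global experts.
tokens_per_expert = torch.tensor([5, 3, 2, 6], dtype=torch.int32)

# Row i counts the local tokens destined for rank i's experts.
send_counts = tokens_per_expert.view(world_size, experts_per_rank).sum(dim=-1)
# send_counts -> [8, 8]; recv_counts come from the all_to_all'd counts,
# summed the same way on the receiving side.
```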
+ send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so devices that share experts + # get all of the tokens assigned to them. + # + # TODO(tgale): Fuse this into the prior, local permutation. + x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, + recv_counts, + send_counts, + self.args.expert_parallel_group, + async_op=True, + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + replicate_bins = ops.inclusive_cumsum( + parallel_tokens_per_expert.flatten(), + 0, + ) + replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * mpu.hidden_sharding_degree(self.args), + dtype=torch.int32, + device=indices.device, + ), + mpu.experts_per_rank(self.args), + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, + tokens_received, + ).flatten() + + # TODO(tgale): The sort_end_bit here can be reduced. + parallel_bin_ids, parallel_indices = ops.sort( + parallel_top_expert, + self.sort_end_bit, + ) + + # Calculate the bins boundaries from the token counts. + parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, + dtype=torch.int, + ) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + tokens, _ = x.size() + expert_capacity = self.expert_capacity(tokens) + if expert_capacity == 0: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + if self.args.mlp_impl == 'grouped': + # GroupedMLP requires counts on CPU. We can use the tensor already + # moved to CPU for the prior all_to_all, which avoids an extra + # device synchronization. + parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( + dim=0, + dtype=torch.int, + ) + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all( + parallel_x, + send_counts, + recv_counts, + self.args.expert_parallel_group, + ) + + # Reduce along the hidden sharding to get the final outputs. + # + # TODO(tgale): Fuse this into the following local permutation. 
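The `ops.replicate` call above expands one expert id per (device, local expert) bin by that bin's token count; for small CPU inputs `torch.repeat_interleave` reproduces the same result:

```python
import torch

experts_per_rank = 2
# Flattened [world_size, experts_per_rank] token counts after the all_to_all.
parallel_tokens_per_expert = torch.tensor([3, 1, 2, 2])

parallel_top_expert = torch.remainder(
    torch.arange(parallel_tokens_per_expert.numel()), experts_per_rank)
expanded = torch.repeat_interleave(parallel_top_expert, parallel_tokens_per_expert)
# expanded -> [0, 0, 0, 1, 0, 0, 1, 1]: the local expert id for each received token.
```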
+ shape = ( + mpu.hidden_sharding_degree(self.args), + -1, + self.args.hidden_size, + ) + x = ops.sum(x.view(shape), dim=0) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + return x, tokens_per_expert.flatten() + + def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + in_shape = x.size() + + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) + if self.training and self.args.moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, scores)) + x = x.view(in_shape) + if self.bias is not None: + if self.args.return_bias: + return x, self.bias + return x + self.bias + return x + + +class MoE(torch.nn.Module): + + def __init__(self, args: Arguments): + super(MoE, self).__init__() + + # Token router. + self.router = router.LearnedRouter(args) + + # Expert computation helper. + self.experts = self._init_experts_mlp(args) + + self.shared_expert = None + if args.shared_expert: + # SharedExpert computation helper. + self.shared_expert = sharedexpert_registry.get(args) + + def _init_experts_mlp(self, args: Arguments): + return ParallelMLP(args) + + def forward(self, x: torch.Tensor): + # NOTE: If we're going to cast the activations to lower precision + # do it before we permute the tokens to save bandwidth. + x = common.cast_if_autocast_enabled(x) + + # Compute the expert scores and assignments. + scores, expert_weights, top_experts = self.router(x) + + # Compute the experts. + out = self.experts(x, scores, expert_weights, top_experts) + if self.shared_expert is not None: + shared_expert_out = self.shared_expert(x) + out = self.shared_expert.add_experts_sharedexpert( + shared_expert_out, + out, + ) + return out diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py new file mode 100644 index 0000000000000000000000000000000000000000..b23213902f4567d7fdb0158cbcf5406c2b2aa601 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py @@ -0,0 +1,93 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.distributed as dist + +from megablocks.layers.arguments import Arguments + + +class MoeParam(torch.Tensor): + + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool + + +def is_moe_param(tensor: torch.Tensor) -> bool: + return hasattr(tensor, 'expert_model_parallel') + + +def get_expert_parallel_world_size(args: Arguments) -> int: + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) + + +def get_expert_parallel_rank(args: Arguments) -> int: + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) + + +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, 'expert_model_parallel') + setattr(tensor, 'expert_model_parallel', is_parallel) + + +def param_is_expert_model_parallel(param: MoeParam) -> bool: + return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) + + +def copy_expert_model_parallel_attributes( + destination_tensor: torch.Tensor, + source_tensor: torch.Tensor, +): + if hasattr(source_tensor, 'expert_model_parallel'): + setattr( + destination_tensor, + 'expert_model_parallel', + getattr(source_tensor, 
'expert_model_parallel'), + ) + + +def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + for i in range(world_size): + dist.barrier(group) + if i == rank: + print(f'rank = {rank}', *x) + + +# Helpers for expert/tensor sharding. +def expert_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = min(world_size, args.moe_num_experts) + + if (args.moe_num_experts % esd) != 0: + raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) + return esd + + +def hidden_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = expert_sharding_degree(args) + hsd = world_size // esd + + if (args.ffn_hidden_size % hsd) != 0: + raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", + ) + return hsd + + +def experts_per_rank(args: Arguments) -> int: + return args.moe_num_experts // expert_sharding_degree(args) + + +def features_per_rank(args: Arguments) -> int: + return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/router.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9dcd9e433322482a80d5b95afccee5c12368f8 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/router.py @@ -0,0 +1,114 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch + +from megablocks.layers import common +from megablocks.layers.arguments import Arguments + +_ROUTER_LOGITS = [] + + +def _save_router_logits(logits: torch.Tensor, args: Arguments): + if args.moe_zloss_weight == 0: + return + global _ROUTER_LOGITS + _ROUTER_LOGITS.append(logits) + + +def clear_router_zloss(): + global _ROUTER_LOGITS + _ROUTER_LOGITS.clear() + + +def batched_router_zloss(args: Arguments): + global _ROUTER_LOGITS + + if args.moe_zloss_weight == 0: + import warnings + warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') + return 0 + + logits_per_router = _ROUTER_LOGITS + + if args.moe_zloss_in_fp32: + logits_per_router = [logits.float() for logits in logits_per_router] + + unscaled_zloss_per_router = torch.stack([ + torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router + ]) + + return args.moe_zloss_weight * unscaled_zloss_per_router + + +# NOTE: To enable end-to-end benchmarking without convergence we +# support a flag to force the router to assign tokens uniformly +# across the experts. We do this with a custom autograd operation +# so that PyTorch still executes the full set of router operation. +class _UniformExpertAssignment(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, num_experts: int): + out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) + out = torch.remainder(out, num_experts) + return out.view(x.shape) + + +_uniform_expert_assignment = _UniformExpertAssignment.apply + + +class LearnedRouter(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + + # Learned router parameters. + # + # NOTE: This weight matrix is not parallelized with expert model + # parallelism. 
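A worked example of the sharding helpers above, using hypothetical sizes: with 4 experts on 8 expert-parallel ranks, the experts are sharded 4 ways and the FFN hidden dimension 2 ways, so each rank owns one expert shard covering half of the FFN features.

```python
# Hypothetical configuration for illustration.
world_size, moe_num_experts, ffn_hidden_size = 8, 4, 4096

expert_sharding_degree = min(world_size, moe_num_experts)       # 4
hidden_sharding_degree = world_size // expert_sharding_degree   # 2
experts_per_rank = moe_num_experts // expert_sharding_degree    # 1
features_per_rank = ffn_hidden_size // hidden_sharding_degree   # 2048
```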
Each device needs the entire router weight matrix + # so that it can route its batch of data correctly. + self.layer = torch.nn.Linear( + args.hidden_size, + args.moe_num_experts, + bias=False, + dtype=common.dtype(args), + device=args.device, + ) + args.init_method(self.layer.weight) + + def jitter(self, x: torch.Tensor): + low: float = 1.0 - self.args.moe_jitter_eps + high: float = 1.0 + self.args.moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return low + noise * (high - low) + + def _top_k(self, scores: torch.Tensor): + if self.args.moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, self.args.moe_top_k, dim=-1) + + def forward(self, x: torch.Tensor): + if self.training and self.args.moe_jitter_eps is not None: + x = x * self.jitter(x) + + logits = self.layer(x.view(-1, x.shape[-1])) + _save_router_logits(logits, self.args) + scores = logits.softmax(dim=-1) + expert_weights, expert_indices = self._top_k(scores) + if self.args.moe_normalize_expert_weights: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + + expert_indices = ( + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) if self.args.uniform_expert_assignment else expert_indices + ) + return scores, expert_weights, expert_indices diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..0f62db39a2dda5642f6baa9d9b949f7c31cf6d35 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -0,0 +1,30 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +_REGISTRY = { + 'mlp': mlp.SharedMLP, + 'glu': glu.SharedGLU, +} + + +def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: + """Returns an SharedMLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + SharedMLP instance. + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated SharedMLP constructed using the input args. 
+ """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + return _REGISTRY[args.mlp_type](args) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from megablocks.ops.binned_gather import binned_gather +from megablocks.ops.binned_scatter import binned_scatter +from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum +from megablocks.ops.gather import gather +from megablocks.ops.histogram import histogram +from megablocks.ops.padded_gather import padded_gather +from megablocks.ops.padded_scatter import padded_scatter +from megablocks.ops.repeat import repeat +from megablocks.ops.replicate import replicate +from megablocks.ops.round_up import round_up +from megablocks.ops.scatter import scatter +from megablocks.ops.sort import sort +from megablocks.ops.sum import sum +from megablocks.ops.topology import topology + +__all__ = [ + 'binned_gather', + 'binned_scatter', + 'exclusive_cumsum', + 'inclusive_cumsum', + 'gather', + 'histogram', + 'padded_gather', + 'padded_scatter', + 'repeat', + 'replicate', + 'round_up', + 'scatter', + 'sort', + 'sum', + 'topology', +] diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..47b95301dbb35b0c3154d1ff2878d20792ce5cb2 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -0,0 +1,60 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + +from megablocks import benchmark_util +from megablocks.layers.all_to_all import all_to_all + +_ALL_TO_ALL_BENCHMARK = ( + (8, 1024), + (16, 1024), + (32, 1024), + (64, 1024), + (128, 1024), + (256, 1024), + (512, 1024), + (1024, 1024), + (2 * 1024, 1024), + (4 * 1024, 1024), + (8 * 1024, 1024), + (16 * 1024, 1024), + (32 * 1024, 1024), + (64 * 1024, 1024), + (128 * 1024, 1024), + (256 * 1024, 1024), + (512 * 1024, 1024), + (1024 * 1024, 1024), +) + + +def benchmark_all_to_all(group, sl, hs): + world_size = dist.get_world_size(group) + assert (sl % world_size) == 0 + send_recv_sizes = [sl // world_size] * world_size + + x = torch.randn((sl, hs)).cuda().half() + + details = { + 'world_size': world_size, + 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. 
+ } + + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) + + time, std = benchmark_util.benchmark_function(benchmark) + + if dist.get_rank(group) == 0: + benchmark_util.log_benchmark('All-To-All', details, time, std) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _ALL_TO_ALL_BENCHMARK: + benchmark_all_to_all(group, *args) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..89cce1b627137d7f987baca62fd6d82c5c04659a --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py @@ -0,0 +1,37 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_gather kernel. +class BinnedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): + ctx.save_for_backward(indices, bins) + ctx.top_k = top_k + return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + indices, bins = ctx.saved_tensors + out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) + return out, None, None, None, None + + +binned_gather = BinnedGatherOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ce0d6f5a890a58b89e6a501216a9822341323f --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py @@ -0,0 +1,59 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_scatter kernel. +class BinnedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + assert len(x.size()) == 3 + ctx.bin_size = x.size(1) + ctx.top_k = top_k + + # TODO(tgale): Don't save 'x' for backwards if we don't need to + # calculate the gradient w.r.t. 'weights'. 
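For context, a rough dense-Python rendition of what the `binned_gather` kernel computes for `top_k=1`: each expert receives at most `expert_capacity` of its tokens in a `[num_experts, expert_capacity, hidden]` buffer, and overflow tokens are dropped. The real kernel is fused and runs on the GPU; this is only a CPU sketch with made-up sizes.

```python
import torch

num_experts, expert_capacity, hidden = 2, 2, 4
x = torch.randn(5, hidden)                    # 5 tokens
top_expert = torch.tensor([1, 0, 1, 1, 0])
bin_ids, indices = torch.sort(top_expert)
bins = torch.cumsum(torch.bincount(top_expert, minlength=num_experts), dim=0)

out = torch.zeros(num_experts, expert_capacity, hidden)
start = 0
for e in range(num_experts):
    end = int(bins[e])
    take = min(end - start, expert_capacity)  # drop overflow beyond capacity
    out[e, :take] = x[indices[start:start + take]]
    start = end
```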
+ ctx.save_for_backward(x, indices, weights, bins) + return kernels.binned_scatter(x, indices, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + x, indices, weights, bins = ctx.saved_tensors + out = kernels.binned_gather( + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[2]: + wgrad = kernels.binned_scatter_wgrad( + x, + grad, + indices, + bins, + ctx.top_k, + ) + return out, None, wgrad, None, None + + +binned_scatter = BinnedScatterOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7c5137c900d66a959a76fa681436362a4df906 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py @@ -0,0 +1,52 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + # import megablocks_ops as ops # type: ignore + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrappers for cumsum kernels. +# NOTE: Does not support gradients. +class ExclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int): + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.exclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.exclusive_cumsum(x, dim, out) + return out + + +exclusive_cumsum = ExclusiveCumsumOp.apply + + +class InclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.inclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.inclusive_cumsum(x, dim, out) + return out + + +inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py new file mode 100644 index 0000000000000000000000000000000000000000..41b09a1233e8a996ff1062579cd0810d095ad1e6 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for gather kernel. 
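The two cumsum flavors used throughout the routing code differ only by a shift; an equivalent CPU computation with plain `torch.cumsum` for a small example:

```python
import torch

tokens_per_expert = torch.tensor([2, 3, 0, 4])

inclusive = torch.cumsum(tokens_per_expert, dim=0)   # [2, 5, 5, 9]
exclusive = inclusive - tokens_per_expert            # [0, 2, 5, 5]

# The wrappers above view a 1-D input as [1, -1], run the kernel along dim=1,
# then squeeze the result back to 1-D.
```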
+class GatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins) + ctx.top_k = top_k + return kernels.gather(x, indices, bin_ids, None, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins = ctx.saved_tensors + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) + return out, None, None, None, None, None + + +gather = GatherOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py new file mode 100644 index 0000000000000000000000000000000000000000..77e06a29163aa335dc61e8c50932d5f83322ae1d --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py @@ -0,0 +1,27 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for histogram kernel. +# NOTE: Does not support gradients. +class HistogramOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, max_val: float): + return ops.histogram(x, max_val) + + +histogram = HistogramOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..9de8e65281c829ba63b5c60853dfa5bb0a333988 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -0,0 +1,78 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_HISTOGRAM_TESTS = ( + (16384, torch.int32, 2), + (16384, torch.int32, 4), + (16384, torch.int32, 8), + (16384, torch.int32, 16), + (16384, torch.int32, 32), + (16384, torch.int32, 64), + (16384, torch.int32, 128), + (16384, torch.int32, 256), +) + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. 
+ fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class HistogramBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testHistogram(self, n, dtype, max_val): + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testTorchHistogram(self, n, dtype, max_val): + x = torch.randint(0, 128, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..bfa7b7c2f697f9454c5381e31ee78040be5b229e --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -0,0 +1,403 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import stk +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + + +# Calling tensor.t() calls tensor.transpose(0, 1) which calls +# torch.as_strided(...). Circumvent this chain to avoid an overhead +# this adds. +def transpose_view(x): + return torch.as_strided( + x, + (x.shape[1], x.shape[0]), + (x.stride()[1], x.stride()[0]), + ) + + +_MATMUL_TESTS = ( + (64 * 1024, 512, 2048, 64), + (32 * 1024, 768, 3072, 64), + (8 * 1024, 1024, 4096, 64), + (4 * 2048, 4096, 4 * 4096, 4), +) + + +def log_benchmark(name, arguments, time, std, flops): + benchmark_util.log_benchmark(name, arguments, time, std) + print('flops = {:.2f}B'.format(flops / 1e9)) + print('throughput = {:.2f}T'.format(flops / 1e9 / time)) + print('=' * 60) + + +class MatmulBenchmark(parameterized.TestCase): + + def build_sparse_matrix(self, x, padded_bins, fhs, ne): + blocking = 128 + padded_tokens, _ = x.size() + assert padded_tokens % blocking == 0 + assert fhs % blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // blocking + blocks_per_row = fhs // blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. 
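The offsets built above are ordinary CSR-style row pointers: every padded block row touches exactly `fhs // blocking` nonzero blocks of its expert's weight matrix. With small hypothetical sizes:

```python
import torch

blocking, padded_tokens, fhs = 128, 512, 256
block_rows = padded_tokens // blocking   # 4 rows of blocks
blocks_per_row = fhs // blocking         # 2 nonzero blocks per row

# Row pointers into the list of nonzero blocks.
offsets = torch.arange(0, block_rows * blocks_per_row + 1, blocks_per_row,
                       dtype=torch.int32)
# offsets -> [0, 2, 4, 6, 8]
```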
+ column_indices = ops.topology( + padded_bins, + blocking, + block_rows, + blocks_per_row, + ) + data = torch.empty( + column_indices.numel(), + blocking, + blocking, + dtype=torch.float16, + device=x.device, + ) + shape = (padded_tokens, fhs * ne) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix(shape, data, row_indices, column_indices, offsets) + + def build_input_matrix(self, sl, hs, ne): + x = torch.randn((sl, hs)).cuda().half() + + # Assign tokens to experts uniformly. + top_expert = torch.arange(0, sl).cuda().int() % ne + + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) + return out, padded_bins + + def build_weight_matrix(self, ne, hs, fhs): + return torch.randn((hs, ne * fhs)).cuda().half() + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(x, w, topo) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd::SDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(topo, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradX::DSD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + topo = topo.t() + + def benchmark(): + return stk.ops.dsd(topo, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DSD::NN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): 
+ x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(out, w, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::SDD::NT', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + x = x.t() + + def benchmark(): + return stk.ops.dsd(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + + w = w.transpose(1, 2).contiguous() + w = w.transpose(1, 2) + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd:DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + w = w.transpose(1, 2).contiguous() + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradX:DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + out = out.transpose(1, 2) + + def benchmark(): + return torch.bmm(out, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradW:DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) 
+ + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + w = torch.transpose(w, 1, 2) + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + x = torch.transpose(x, 1, 2) + + def benchmark(): + return torch.bmm(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..f272a7768dc6468e81bd2cd25294ca16a6826c08 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py @@ -0,0 +1,55 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_gather kernel. +class PaddedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins, padded_bins) + ctx.top_k = top_k + return kernels.padded_gather( + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins, padded_bins = ctx.saved_tensors + out = kernels.padded_scatter( + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) + return out, None, None, None, None, None + + +padded_gather = PaddedGatherOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff81dd9456a04258346266e9e85871bda56c65b --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py @@ -0,0 +1,98 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_scatter kernel. 
+class PaddedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward( + indices, + bin_ids, + weights, + bins, + padded_bins, + *maybe_x, + ) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.padded_gather( + grad, + indices, + bin_ids, + weights, + bins, + padded_bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.padded_scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + padded_bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None, None + + +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): + return PaddedScatterOp.apply( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..81dde4e4b4c0f550c6027390e8de285fa4d842f7 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -0,0 +1,66 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PADDED_SCATTER_BENCHMARK = ( + # dMoE-Medium, 8-way EMP. + (1024 * 16, 1024, 8, 4), + # dMoE-Medium, post-all-to-all. + (1024 * 16 * 4, 1024, 8, 1), +) + + +class PaddedScatterTest(parameterized.TestCase): + + @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) + def testPaddedScatter(self, sl, hs, ne, top_k): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + # Sample weights for the scatter reduce. + weights = torch.rand((sl * top_k,)).cuda().half() + + # Gather the data to prepare for backwards. 
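The `padded_bins` consumed by the padded gather/scatter pair come from rounding each expert's token count up to the block size before taking the cumulative sum; a CPU sketch of the same arithmetic (mirroring the `ops.round_up` helper) with made-up counts:

```python
import torch

blocking = 128
tokens_per_expert = torch.tensor([100, 259, 0, 17], dtype=torch.int32)

# Round each expert's count up to a multiple of the block size, then
# build inclusive bin boundaries over the padded counts.
padded = torch.div(tokens_per_expert + (blocking - 1), blocking,
                   rounding_mode='trunc') * blocking
padded_bins = torch.cumsum(padded, dim=0)            # [128, 512, 512, 640]
bins = torch.cumsum(tokens_per_expert, dim=0)        # [100, 359, 359, 376]
```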
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + def benchmark(): + return ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + time, std = benchmark_util.benchmark_function(benchmark) + benchmark_util.log_benchmark( + 'Padded Scatter', + { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + 'top_k': top_k, + }, + time, + std, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..837f07e2858d5c65c14e3f262617e7d2985ce901 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -0,0 +1,149 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PERMUTE_TESTS = ( + (16384, 768, 2), + (16384, 768, 4), + (16384, 768, 8), + (16384, 768, 16), + (16384, 768, 32), + (16384, 768, 64), + (16384, 768, 128), + (16384 * 8, 768, 2), + (16384 * 8, 768, 4), + (16384 * 8, 768, 8), + (16384 * 8, 768, 16), + (16384 * 8, 768, 32), + (16384 * 8, 768, 64), + (16384 * 8, 768, 128), +) + + +class PermuteBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedGather(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.binned_gather(x, indices, bins, ec) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedScatter(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.binned_gather(x, indices, bins, ec) + + def benchmark(): + return ops.binned_scatter(x, indices, bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedGather(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. 
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedScatter(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testCopy(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + # ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + y = x.clone() + + def benchmark(): + return y.copy_(x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/repeat.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/repeat.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def repeat(x: torch.Tensor, tiling: torch.Size): + if all((t == 1 for t in tiling)): + return x + return x.repeat(*tiling) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..e570e22093339d1b71dc830edf656897c67f40e7 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py @@ -0,0 +1,36 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. 
+try:
+    from megablocks._ops import ops  # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+        ctx.save_for_backward(bins)
+        out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+        ops.replicate_forward(x, bins, out)
+        return out
+
+    @staticmethod
+    def backward(ctx: Any, grad: torch.Tensor):
+        bins, = ctx.saved_tensors
+        out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+        ops.replicate_backward(grad, bins, out)
+        return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/round_up.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+    assert isinstance(value, int)
+    assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+    # do this in a custom kernel. We only expect
+    # to use this on arrays of less than 1k elements.
+    return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5aaafc402a651a47595c25dbf71e382903e5022
--- /dev/null
+++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from stk.backend.autocast import custom_bwd, custom_fwd
+
+from megablocks.backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
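+# scatter is the inverse of the gather permutation used for MoE routing: each
+# row of the expert-ordered activations is written back to the position of the
+# token that produced it, scaled by its router weight, and the top_k copies of
+# every token are summed into a single output row. The backward pass below
+# recovers the activation gradient with a gather and the router-weight
+# gradient with a dedicated wgrad kernel.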
+class ScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins = saved_tensors[:4] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.gather( + grad, + indices, + bin_ids, + weights, + bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None + + +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +) -> Optional[torch.Tensor]: + return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py new file mode 100644 index 0000000000000000000000000000000000000000..a64a4553323d8de71a1d5f1821bac50249c68197 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional, Tuple + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + +_BITS_FOR_DTYPE = { + torch.int16: 16, + torch.int32: 32, + torch.int64: 64, +} + + +# Autograd wrapper for sort kernel. +# NOTE: Does not support gradients. 
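+# The kernel performs a radix sort keyed on the low `end_bit` bits of the
+# input and also returns the permutation that produced the sorted order, e.g.
+# sort([2, 0, 1]) -> values [0, 1, 2], indices [1, 2, 0]. MoE routing uses
+# this to group token positions by their assigned expert.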
+class SortOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + if end_bit is None: + end_bit = _BITS_FOR_DTYPE[x.dtype] + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + ops.sort(x, end_bit, x_out, iota_out) + return (x_out, iota_out) + + +sort = SortOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -0,0 +1,85 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_SORT_TESTS = ( + (16384, torch.int32, None), + (16384, torch.int32, 2), + (16384, torch.int32, 128), +) + +_BASELINE_SORT_TESTS = ((16384,),) + + +def numpy_dtype(dtype): + types = { + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + } + return types[dtype] + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. + fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class SortBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_SORT_TESTS) + def testSort(self, n, dtype, max_val): + if max_val is None: + max_val = np.iinfo(numpy_dtype(dtype)).max + end_bit = int(np.ceil(np.log2(max_val))) + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_BASELINE_SORT_TESTS) + def testTorchSort(self, n): + x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) + arguments = { + 'n': n, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sum.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sum.py new file mode 100644 index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sum.py @@ -0,0 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import torch + + +def sum(x: torch.Tensor, dim: int = 0): + if x.shape[dim] == 1: + return x.squeeze(dim=dim) + return x.sum(dim=dim) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py new file mode 100644 index 
0000000000000000000000000000000000000000..b4a693a82a36c052e51afb4a02077e7849a4abf1 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py @@ -0,0 +1,45 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for topology kernel. +# NOTE: Does not support gradients. +class TopologyOp(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, + ): + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + ops.indices( + padded_bins, + block_size, + output_block_rows, + output_block_columns, + out, + ) + return out + + +topology = TopologyOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d8a2495c4baab89e5f90d84fa92f406a29ea10 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/__init__.py @@ -0,0 +1,195 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from ._ops import ops + +from megablocks.layers.arguments import Arguments +from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE +from megablocks.layers.glu import SparseGLU +from megablocks.layers.mlp import MLP, SparseMLP +from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +# This section contains the direct kernel exports (not inlcuded in the original code) +def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute exclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.exclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute inclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.inclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: + """ + Compute histogram of input tensor values. + + Args: + x: Input tensor + num_bins: Number of histogram bins + + Returns: + Histogram tensor with counts for each bin + """ + return ops.histogram(x, num_bins) + + +def indices( + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, +) -> torch.Tensor: + """ + Construct indices from padded bins for sparse operations. 
+ + Args: + padded_bins: Tensor containing bin boundaries + block_size: Size of each block + output_block_rows: Number of rows in output blocks + output_block_columns: Number of columns in output blocks + + Returns: + Tensor containing constructed indices + """ + return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) + + +def replicate_forward( + x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Forward pass of replicate operation - replicate values according to bin sizes. + + Args: + x: Input tensor with values to replicate + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_forward(x, bins, out) + + +def replicate_backward( + grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Backward pass of replicate operation - reduce gradients back to bins. + + Args: + grad: Gradient tensor to reduce + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_backward(grad, bins, out) + + +def sort( + x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor +) -> torch.Tensor: + """ + Radix sort with index tracking. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + x_out: Output tensor for sorted values + iota_out: Output tensor for sorted indices + + Returns: + The sorted values tensor + """ + return ops.sort(x, end_bit, x_out, iota_out) + + +# Convenience functions for common use cases +def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: + """ + Compute cumulative sum with automatic output allocation. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum (default: last dimension) + exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum + + Returns: + New tensor containing the cumulative sum + """ + out = torch.empty_like(x) + if exclusive: + return exclusive_cumsum(x, dim, out) + else: + return inclusive_cumsum(x, dim, out) + + +def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sort tensor and return both sorted values and indices. 
+ + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + + Returns: + Tuple of (sorted_values, sorted_indices) + """ + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + sort(x, end_bit, x_out, iota_out) + return x_out, iota_out + + +# Export public API +__all__ = [ + # Direct kernel exports + "exclusive_cumsum", + "inclusive_cumsum", + "histogram", + "indices", + "replicate_forward", + "replicate_backward", + "sort", + "cumsum", + "argsort", + # Original exports + "Arguments", + "ParallelDroplessMLP", + "dMoE", + "SparseGLU", + "MLP", + "SparseMLP", + "MoE", + "ParallelMLP", + "get_load_balancing_loss", +] diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..0be0cd1a5f3cab09576c25a7e8e799c8097bfdfc --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9661a5704bd53f129788d2b9241f44cf00e1447dab8127d5a07675b1d6ca2ba +size 10444224 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd42432c910375fb1713b3bf29145e48b556498 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _megablocks_359242d +ops = torch.ops._megablocks_359242d + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_megablocks_359242d::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_version.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_version.py @@ -0,0 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +"""The MegaBlocks Version.""" + +__version__ = '0.11.0.dev0' diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/backend/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/backend/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/backend/kernels.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/backend/kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..b584ceede926ca30abef2dec581cb3ff329e8e16 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/backend/kernels.py @@ -0,0 +1,543 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') + + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') + + +def assert_equal(a, b): + if a != b: + raise ValueError(f'Expected 
dimensions to be equal but got {a} and {b}.',) + + +# a: (tokens, hidden_size), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy( + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Our index into array 'a'. + index_a = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'b'. + index_b = offset_in_bin + if bin_idx > 0: + index_b += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: Because of the padding, the output size is dynamic. + # We load the final padded bin bound to get the output rows. + output_rows = padded_bins[-1].cpu().item() + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def gather(x, indices, bin_ids, weights, bins, top_k): + # Validate the input shapes. 
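+    # NOTE: This is the unpadded variant of padded_gather. The same
+    # _padded_copy kernel is launched below with `bins` passed for both the
+    # bin and padded-bin offsets, so the output holds exactly
+    # tokens * top_k rows with no padding.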
+ assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: There is no padding so the output rows equals the + # input rows multiplied by top_k. + output_rows = x.shape[0] * top_k + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + out, + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + + +def scatter(x, indices, bin_ids, weights, bins, top_k): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +# x: (tokens, top_k, hidden_size), real +# grad: (tokens, hidden_size), real. +# wgrad: (tokens, top_k), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy_wgrad( + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Our index into 'tokens * top_k'. + index_out = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'x'. + index_x = offset_in_bin + if bin_idx > 0: + index_x += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. 
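+    # wgrad holds one scalar per routed (token, expert) slot. grad is indexed
+    # by the destination token (index_out // TOP_K) while x lives at the
+    # padded row index_x, so the loop below accumulates the dot product of the
+    # routed activation row with the incoming gradient row.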
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + _padded_copy_wgrad[(indices.shape[0],)]( + x, + grad, + out, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + TOP_K=top_k, + ) + return out + + +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy( + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_b = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_a = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + # + # NOTE: We need to zero the output in both directions. 
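+    # Rows past num_tokens hit the early return above and are never written,
+    # which is why binned_gather and binned_scatter allocate their outputs
+    # with torch.zeros rather than torch.empty.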
+ iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + num_experts = bins.shape[0] + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + + _binned_copy[(num_experts, expert_capacity)]( + x, + out, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def binned_scatter(x, indices, weights, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( + out, + x, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=hidden_size, + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy_wgrad( + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_x = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_out = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. 
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def binned_scatter_wgrad(x, grad, indices, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) + _binned_copy_wgrad[(num_experts, expert_capacity)]( + x, + grad, + out, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS=hidden_size, + TOP_K=top_k, + ) + return out diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/bak.__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/bak.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5217959caf74527e3bf7f80db6f93be21c016963 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/bak.__init__.py @@ -0,0 +1,23 @@ +from megablocks_moe.megablocks import ( + MoE, + dMoE, + get_load_balancing_loss, + ParallelMLP, + ParallelDroplessMLP, + SparseMLP, + MLP, + SparseGLU, + Arguments, +) + +__all__ = [ + "MoE", + "dMoE", + "get_load_balancing_loss", + "ParallelMLP", + "ParallelDroplessMLP", + "SparseMLP", + "MLP", + "SparseGLU", + "Arguments", +] diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/benchmark_util.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/benchmark_util.py new file mode 100644 index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/benchmark_util.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + + +def log_benchmark(name, arguments, time, std): + print('=' * 60) + print(f'{name} Benchmark') + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) + print('=' * 60) + + +def benchmark_function(fn, iterations=100, warmup=10): + # Warmup iterations. 
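+    # The warmup runs absorb one-time costs such as kernel compilation and
+    # Triton autotuning so they do not skew the timed iterations. Each timed
+    # iteration below brackets the call with CUDA events and synchronizes
+    # before reading the elapsed GPU time.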
+ for _ in range(warmup): + fn() + + times = [] + for i in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + fn() + end.record() + + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + return np.mean(times), np.std(times) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6d3f977f360fd0ad5800c3b5da9ce57be794b9b8 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm_util.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import warnings + +_grouped_gemm_is_available: bool = False +try: + import grouped_gemm + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn('Grouped GEMM not available.') + + +def grouped_gemm_is_available(): + return _grouped_gemm_is_available + + +def assert_grouped_gemm_is_available(): + msg = ( + 'Grouped GEMM not available. Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg + + +backend = grouped_gemm.backend if grouped_gemm_is_available() else None +ops = grouped_gemm.ops if grouped_gemm_is_available() else None diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..849b023b1a0f765fb1e26addf4670dc6db785a52 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# from megablocks.layers.dmoe import dMoE +from megablocks.layers.moe import MoE + +__all__ = [ + 'MoE', + # 'dMoE', +] diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/activation_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..a31770ba179fed06abf2da10102ccaeed1d3ee4e --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/activation_fn.py @@ -0,0 +1,33 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Union + +import torch +from stk import Matrix + + +def act_fn( + x: Matrix, + function: Callable, + return_grad_fn: bool = False, + **kwargs, +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) + with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): + if return_grad_fn: + x.data.requires_grad = True + out = function(x.data, **kwargs) + y = Matrix( + x.size(), + out, + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + if return_grad_fn: + return y, out.backward + return y diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/all_to_all.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/all_to_all.py @@ -0,0 +1,54 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + + +class 
AllToAllOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) + + ctx.input_shape = x.shape + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + ctx.group = group + handle = dist.all_to_all_single( + out, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + return out, handle + + @staticmethod + def backward(ctx, grad, _): + if ctx.needs_input_grad[0]: + out = torch.empty( + ctx.input_shape, + device=grad.device, + dtype=grad.dtype, + ) + dist.all_to_all_single( + out, + grad, + output_split_sizes=ctx.input_split_sizes, + input_split_sizes=ctx.output_split_sizes, + group=ctx.group, + ) + return out, None, None, None, None + return None, None, None, None, None + + +def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): + return AllToAllOp.apply( + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..3962c771c90012535aab058443f89c541a2e9236 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/arguments.py @@ -0,0 +1,100 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from functools import partial +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +import megablocks.grouped_gemm_util as grouped_gemm + +# Type annotation for in-place Tensor initialization function. +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] + +_ALLOWED_BITWIDTHS = (-1, 4, 8) + +DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') + + +@dataclasses.dataclass +class Arguments: + # Model arguments. + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + num_layers: int = 1 + bias: bool = True + return_bias: bool = True + activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN + + # MoE arguments. + moe_num_experts: int = 1 + moe_top_k: int = 1 + moe_capacity_factor: int = 1 + moe_normalize_expert_weights: Optional[Union[int, float]] = None + moe_loss_weight: float = 0.1 + moe_jitter_eps: Optional[float] = None + moe_lbl_in_fp32: bool = False + + # Parallelism arguments. + moe_expert_model_parallelism: bool = False + expert_parallel_group: Optional[dist.ProcessGroup] = None + pipeline_model_parallel_size: int = 1 + num_layers_per_virtual_pipeline_stage: Optional[int] = None + + # Compute arguments. + memory_optimized_mlp: bool = False + mlp_type: str = 'mlp' + mlp_impl: str = 'sparse' + + # Initialization arguments. + fp16: bool = True + bf16: bool = False + device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) + init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) + output_layer_init_method: InitFn = init_method + + # Benchmarking arguments. 
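+    # When enabled, the learned routing is replaced with a fixed uniform
+    # token-to-expert assignment so benchmarks see a perfectly balanced load.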
+ uniform_expert_assignment: bool = False + + # shared expert arguments + shared_expert: bool = False # enable using shared expert + fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) + fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers + remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored + shared_expert_hidden_size: Optional[ + int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size + shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) + + # Router Z-loss arguments + moe_zloss_weight: float = 0 # 1e-3 is a reasonable value + moe_zloss_in_fp32: bool = False + + def __post_init__(self): + # Sparse MLP is not supported with triton >=3.2.0 + # TODO: Remove this once sparse is supported with triton >=3.2.0 + if self.__getattribute__('mlp_impl') == 'sparse': + try: + import triton + if triton.__version__ >= '3.2.0': + raise ValueError( + 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.', + ) + except ImportError: + raise ImportError('Triton is required for sparse MLP implementation') + + if self.__getattribute__('mlp_impl') == 'grouped': + grouped_gemm.assert_grouped_gemm_is_available() + + if self.shared_expert_hidden_size is None: + self.shared_expert_hidden_size = self.ffn_hidden_size + + +def from_megatron(megatron_args: Any): + args = Arguments() + for field in dataclasses.fields(args): + if hasattr(megatron_args, field.name): + setattr(args, field.name, getattr(megatron_args, field.name)) + return args diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/common.py new file mode 100644 index 0000000000000000000000000000000000000000..ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/common.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from megablocks.layers.arguments import Arguments + + +def dtype(args: Arguments): + if args.fp16: + return torch.float16 + elif args.bf16: + return torch.bfloat16 + return None + + +def cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == 'cuda': + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == 'cpu': + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d765bd04387a29bddd2789fae04821635b555e82 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -0,0 +1,42 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +MlpType = Union[mlp.SparseMLP, glu.SparseGLU] + +_REGISTRY = { + 'mlp': { + 'grouped': mlp.GroupedMLP, + 'sparse': mlp.SparseMLP, + }, + 'glu': { + 'grouped': glu.GroupedGLU, + 'sparse': glu.SparseGLU, + }, +} + + +def get(args: Arguments) -> MlpType: + 
"""Returns an MLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + MLP instance. This only contains MLPs for use in dMoEs + (ie. only for the dropless versions of MoEs). + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated MLP constructed using the input args. + """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + if args.mlp_impl not in _REGISTRY[args.mlp_type]: + raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) + + return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..205727ff4d63f9e8dc9648acaac99a97f3394d6f --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmoe.py @@ -0,0 +1,327 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import stk.ops +import torch +from stk import Matrix + +import megablocks.ops as ops +# from megablocks.ops import ops +from megablocks.layers import common, dmlp_registry, moe, mpu +from megablocks.layers.arguments import Arguments + + +def promote_scalar(x): + return x.view(1) if not len(x.size()) else x + + +class ParallelDroplessMLP(moe.ParallelMLP): + + def __init__(self, args: Arguments): + super(ParallelDroplessMLP, self).__init__(args) + self.hidden_size = args.hidden_size + self.ffn_hidden_size = mpu.features_per_rank(args) + self.blocking = 128 + self.mlp = dmlp_registry.get(args) + + # Calculate the number of bits needed to represent the column indices + # in the intermediate sparse matrix. + max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) + self.transpose_sort_end_bit = max( + int(np.ceil(np.log2(max_column_index))), + 1, + ) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + # + # NOTE: Our sort operation uses the same width indices as the input values. + # To avoid overflow when we have large activation matrices we cast to + # 32-bit before sorting. + _, gather_indices = ops.sort( + column_indices.int(), + self.transpose_sort_end_bit, + ) + + # There are a constant number of blocks in every row of the sparse matrix. + # A blocks offset is: + # + # row_index * blocks_per_row + column_index % blocks_per_row + # + # Once we have the block offsets ordered for transposition we can divide + # by blocks_per_row to get the transposed column indices. + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + if nnz_per_column.dim() == 0: + # This addresses an edge case when ffn_hidden_size is equal to self.blocking. 
+ nnz_per_column = nnz_per_column.unsqueeze(0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + if self.ffn_hidden_size % self.blocking != 0: + raise ValueError( + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + + f'the block size {self.blocking}. Please update your configuration.', + ) + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // self.blocking + blocks_per_row = self.ffn_hidden_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology( + padded_bins, + self.blocking, + block_rows, + blocks_per_row, + ) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=common.dtype(self.args), + device='meta', + ) + shape = ( + padded_tokens, + self.ffn_hidden_size * mpu.experts_per_rank(self.args), + ) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, + row_indices, + column_indices, + offsets, + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_experts = top_experts.int() + bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Caculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = promote_scalar(bins) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def sparse_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather( + x, + indices, + bin_ids, + bins, + padded_bins, + self.top_k, + ) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. 
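+        # padded_scatter drops the padding rows, scales every row by its
+        # router weight, and sums the top_k expert outputs for each token back
+        # into a single row of the original [sl * bs, hs] layout.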
+ x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.top_k, + ) + return x, tokens_per_expert + + # For use in the base-class parallel_forward_once. + def sparse_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + return ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + top_k, + ) + + def grouped_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + out = self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + -1, # unused + self.args.moe_top_k, + ) + return out, tokens_per_expert + + def grouped_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Perform the expert computation. + x = self.mlp(x, tokens_per_expert) + + # Un-route the data for the MoE output. 
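+        # scatter is the unpadded counterpart of padded_scatter: the grouped
+        # GEMM path adds no block padding, so plain bins are enough to restore
+        # the original token order.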
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + def forward_once(self, x, expert_weights, top_experts): + if self.args.mlp_impl == 'sparse': + return self.sparse_forward_once(x, expert_weights, top_experts) + else: + return self.grouped_forward_once(x, expert_weights, top_experts) + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ): + if self.args.mlp_impl == 'sparse': + return self.sparse_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + else: + return self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + + +class dMoE(moe.MoE): + + def _init_experts_mlp(self, args: Arguments): + return ParallelDroplessMLP(args) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/gelu.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..40b601d4a8eb59e80e1090f31f4172b8f7fb7549 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/gelu.py @@ -0,0 +1,43 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk +import torch +import torch.nn.functional as F + + +@torch.jit.script +def _gelu_backward_inplace(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) + return g.mul_(ff) + + +def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): + # NOTE: The two sparse matrices must have the same topology. + if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): + return stk.Matrix( + x.size(), + _gelu_backward_inplace(grad.data, x.data), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + return _gelu_backward_inplace(grad, x) + + +def gelu(x: stk.Matrix): + assert isinstance(x, stk.Matrix) + return stk.Matrix( + x.size(), + F.gelu(x.data, approximate='tanh'), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/glu.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe0c915c307e7f7cade3ea3ff679399635fcd81 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/glu.py @@ -0,0 +1,223 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk.ops +import torch + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import Arguments +from megablocks.layers.mlp import ( + SharedMLP, + SparseMLP, + create_dmoe_expert_weights, + resolve_dtensor, +) + + +class SparseGLU(SparseMLP): + + def __init__(self, args: Arguments): + super().__init__(args) + self.v1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + with torch.no_grad(): + self.v1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + + mpu.set_expert_model_parallel_attributes( + 
self.v1, + self._should_set_parallelism_attribute, + ) + + def forward(self, x, topo): + if self.args.memory_optimized_mlp: + raise NotImplementedError( + 'Memory optimized implementation not yet supported with GLU with sparse kernels.', + ) + + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Compute the GLU. + x1 = stk.ops.sdd(x, w1.t(), topo) + x2 = stk.ops.sdd(x, v1.t(), topo) + + activation_fn_out = act_fn(x1, self.args.activation_fn) + x1 = stk.ops.mul(activation_fn_out, x2) + + return stk.ops.dsd(x1, w2) + + +class MemoryOptimizedGroupedGLU(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + v1 = v1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) + + # GeLU. + activation_fn_out = activation_fn(sdd_out) * v1_out + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, v1, w2 = saved_tensors[:3] + batch_sizes = saved_tensors[3] + x = saved_tensors[4] + sdd_out, v1_out = saved_tensors[5:7] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + v1_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * v1_out + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + dv1_out = v1_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dv1. 
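+        # Mirrors the dw1 computation above, using the gradient that flowed
+        # through the gating (v1) branch instead.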
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + dx = ddsd_out + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) + dx += gg.backend.gmm(dv1_out, v1, batch_sizes) + return dx, dw1, dv1, dw2, None, None + + +memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply + + +class GroupedGLU(SparseGLU): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, v1, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.v1), + self.scale_grad(self.w2), + ) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Re-shape the weights for the grouped GEMMs. + ne = mpu.experts_per_rank(self.args) + w1 = w1.view(ne, -1, self.args.hidden_size) + v1 = v1.view(ne, -1, self.args.hidden_size) + w2 = w2.view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_glu( + x, + w1, + v1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) + x1 = self.args.activation_fn(x1) * x2 + return gg.ops.gmm(x1, w2, batch_sizes) + + +class SharedGLU(SharedMLP): + """GPU for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class + """ + + def __init__(self, args: Arguments): + super().__init__(args) + self.gate_proj = args.fc_cls( + args.hidden_size, + self.args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/memory_test.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/memory_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4acbd94f212ce906ace9de78dd3ffc6afa03f97e --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/memory_test.py @@ -0,0 +1,102 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import gc + +import torch +import torch.distributed as dist + +from megablocks.layers import arguments, dmoe + +_TESTS = ((8, 2048, 4096, 4096, 32, 4),) + + +def get_tensors(): + ptrs = set() + out = [] + for obj in gc.get_objects(): + if torch.is_tensor(obj): + if not obj.is_contiguous() or obj.data_ptr() in ptrs: + continue + out.append(obj) + ptrs.add(obj.data_ptr()) + return out + + +def test_memory( + group, + batch_size, + sequence_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, +): + args = arguments.Arguments( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=num_experts, + moe_top_k=top_k, + moe_expert_model_parallelism=True, + expert_parallel_group=group, + fp16=False, + bf16=True, + device=torch.cuda.current_device(), + ) + layer = dmoe.dMoE(args).cuda() + + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16).requires_grad_(True) + torch.cuda.empty_cache() + + # Run forward + backward. + # with torch.autograd.detect_anomaly(): + out, _ = layer(x) + out.mean().backward() + + # Report peak memory. 
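+    #
+    # NOTE: max_memory_allocated()/max_memory_reserved() report the peak
+    # since program start (or the last reset_peak_memory_stats() call),
+    # so the numbers below cover the forward and backward passes above.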
+ mem = torch.cuda.max_memory_allocated() + print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) + print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) + + # Calculate weight and gradient memory usage. + weight_memory = 2 * ( + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() + ) + + def grad_numel(x): + if x.grad is not None: + return x.grad.numel() + return 0 + + grad_memory = 2 * ( + grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + ) + weight_memory += grad_memory + + print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) + print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) + + # Manually calculate GPU memory usage from the garbage + # collector. + gc.collect() + total = 0 + tensors = get_tensors() + tensors = sorted(tensors, key=lambda x: -x.numel()) + for i, t in enumerate(tensors): + total += t.numel() + print(f'{i}: {t.shape}, {t.numel() * 2}') + del tensors + + print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _TESTS: + test_memory(group, *args) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6f4d82441e3a9c185db5bfdf686d53790dde26 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mlp.py @@ -0,0 +1,574 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import stk +import stk.backend.triton_kernels +import stk.ops +import torch +from packaging import version + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, gelu, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + + +class ScaleGradient(torch.autograd.Function): + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx: Any, x: torch.Tensor, scale: float): + ctx.scale = scale + return x + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: torch.Tensor, grad: torch.Tensor): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +def resolve_dtensor(weight: torch.Tensor): + if version.parse(torch.__version__) >= version.parse('2.0.0'): + from torch.distributed._tensor import DTensor + if isinstance(weight, DTensor): + return weight.to_local() + return weight + + +def create_moe_expert_weights( + args: Arguments, + num_experts: int, + ffn_hidden_size: int, + hidden_size: int, + init_method: InitFn, +): + # Create the entire weight matrix such that the sampled weights will + # not vary between data parallelism and expert model parallelism for + # the same random seed. + master_weights = torch.empty( + num_experts, + ffn_hidden_size, + hidden_size, + device=args.device, + dtype=common.dtype(args), + ) + init_method(master_weights) + + if not args.moe_expert_model_parallelism: + return master_weights + + # Calculate the amount of sharding in each dimension. 
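+    #
+    # Experts are split across expert_sharding_degree ranks and each
+    # expert's ffn_hidden_size rows are split across hidden_sharding_degree
+    # ranks; the product of the two equals the expert-parallel world size.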
+ expert_sharding_degree = mpu.expert_sharding_degree(args) + hidden_sharding_degree = mpu.hidden_sharding_degree(args) + + # Calculate the experts per rank. + # + # NOTE: We assign ranks to be expert parallel before going + # tensor parallel. + rank = mpu.get_expert_parallel_rank(args) + expert_rank = rank % expert_sharding_degree + num_experts_per_rank = num_experts // expert_sharding_degree + start_expert = expert_rank * num_experts_per_rank + end_expert = (expert_rank + 1) * num_experts_per_rank + + # Calculate the rows per rank. + row_rank = rank // expert_sharding_degree + num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree + start_row = row_rank * num_rows_per_rank + end_row = (row_rank + 1) * num_rows_per_rank + + # Slice the weight matrix to get the chunk for this rank. + with torch.no_grad(): + weights = master_weights[start_expert:end_expert, start_row:end_row] + return weights + + +class MLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) + experts_per_rank = mpu.experts_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + args.hidden_size, + mpu.features_per_rank(args), + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + mpu.features_per_rank(args), + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + mpu.set_expert_model_parallel_attributes( + self.w1, + args.moe_expert_model_parallelism, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + args.moe_expert_model_parallelism, + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. 
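+        #
+        # The master weights are laid out as [experts, ffn_hidden, hidden];
+        # w1 is stored as [experts, hidden, ffn_hidden] for the bmm in
+        # forward, hence the transpose before the copy below.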
+ with torch.no_grad(): + w1 = create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) + self.w1.copy_(w1.transpose(1, 2).contiguous()) + self.w2.copy_( + create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + x = torch.bmm(x, w1) + x = self.args.activation_fn(x) + return torch.bmm(x, w2) + + +def create_dmoe_expert_weights( + args: Arguments, + num_experts: int, + rows: int, + columns: int, + init_method: InitFn, +): + weights = create_moe_expert_weights( + args, + num_experts, + rows, + columns, + init_method, + ) + return weights.view([-1, columns]) + + +class MemoryOptimizedMLP(torch.autograd.Function): + """Sparse MLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, topo, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + topo_tensors = ( + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) + + # Layer 0: x @ w1.t(). + sdd_out = stk.ops.sdd(x, w1.t(), topo) + + # GeLU. + activation_fn_out = act_fn(sdd_out, activation_fn) + + # Layer 1: x @ w2. + dsd_out = stk.ops.dsd(activation_fn_out, w2) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.shape = topo.shape + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.data.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + topo_tensors = saved_tensors[2:8] + x = saved_tensors[8] + sdd_out_data = saved_tensors[9] + + # rematerialize activation function output + activation_fn = ctx.activation_fn + sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) + activation_fn_out, activation_grad_fn = act_fn( + sdd_out, + activation_fn, + return_grad_fn=True, + ) + + # Compute dw2 with recomputed activation_fn output. + dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. 
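+        #
+        # The sdd kernel below writes ddsd_out @ w2.t() directly into the
+        # activation_fn_out data buffer, so no new block-sparse matrix is
+        # allocated for this gradient.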
+ dactivation_fn_out = activation_fn_out + stk.backend.triton_kernels.sdd( + ddsd_out, + w2.t(), + dactivation_fn_out.shape, + dactivation_fn_out.data, + dactivation_fn_out.offsets, + dactivation_fn_out.row_indices, + dactivation_fn_out.column_indices, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out.data) + dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) + + # Compute dw1. + dw1 = stk.ops.dsd(dsdd_out.t(), x) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + stk.backend.triton_kernels.dsd( + dsdd_out.shape, + dsdd_out.data, + dsdd_out.offsets, + dsdd_out.row_indices, + dsdd_out.column_indices, + dsdd_out.offsets_t, + dsdd_out.column_indices_t, + dsdd_out.block_offsets_t, + False, + w1, + ddsd_out, + ) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_mlp = MemoryOptimizedMLP.apply + + +class SparseMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. + with torch.no_grad(): + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism + mpu.set_expert_model_parallel_attributes( + self.w1, + self._should_set_parallelism_attribute, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + self._should_set_parallelism_attribute, + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x, topo): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + if self.args.memory_optimized_mlp: + return memory_optimized_mlp( + x, + w1, + w2, + topo, + self.args.activation_fn, + ) + + # Compute the MLP. 
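+        #
+        # sdd computes a dense-dense product but only materializes the
+        # blocks described by `topo`; the activation is applied to that
+        # block-sparse result and dsd produces the dense output of the
+        # second matmul (w2).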
+ x = stk.ops.sdd(x, w1.t(), topo) + activation_fn_out = act_fn(x, self.args.activation_fn) + return stk.ops.dsd(activation_fn_out, w2) + + +class MemoryOptimizedGroupedMLP(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + + # activation_fn + activation_fn_out = activation_fn(sdd_out) + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: Any, ddsd_out: torch.Tensor): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + batch_sizes = saved_tensors[2] + x = saved_tensors[3] + sdd_out = saved_tensors[4] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply + + +class GroupedMLP(SparseMLP): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) + + # Re-shape the weights for the grouped GEMMs. 
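+        #
+        # The flat [num_rows_per_rank, hidden_size] parameters are viewed as
+        # [experts_per_rank, ffn_per_expert, hidden_size] so the grouped GEMM
+        # can treat each expert as one group sized by tokens_per_expert.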
+ ne = mpu.experts_per_rank(self.args) + w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) + w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_mlp( + x, + w1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x = self.args.activation_fn(x) + return gg.ops.gmm(x, w2, batch_sizes) + + +class SharedMLP(torch.nn.Module): + """MLP for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class + """ + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self.fc_kwargs: dict[str, Any] = { + 'bias': args.bias, + 'device': args.device, + } + self.fc_kwargs.update(args.fc_kwargs) + + self.up_proj = args.fc_cls( + args.hidden_size, + args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + self.act = args.activation_fn + self.down_proj = args.fc_cls( + args.shared_expert_hidden_size, + args.hidden_size, + **self.fc_kwargs, + ) + self.down_proj._is_residual = True # a flag for llm-foundry init + + def add_experts_sharedexpert( + self, + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + ) -> torch.Tensor: + # Helper function to add expert output to shared expert output + # with optional weighted sum. + if self.args.shared_expert_weighted_sum: + # enable using weighted sum for shared expert output + # wieghted by number of experts used + t_experts = self.args.moe_top_k + 1 + sh_mlp_out = shared_expert_out / t_experts + return sh_mlp_out.add( + expert_out, + alpha=(self.args.moe_top_k / t_experts), + ) + + return shared_expert_out + expert_out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/moe.py @@ -0,0 +1,475 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist + +import megablocks.ops as ops +from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +from megablocks.layers.all_to_all import all_to_all +from megablocks.layers.arguments import Arguments + +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss(args: Arguments): + if args.moe_loss_weight == 0: + return 0.0 + + # tokens_per_expert[i].shape = (num_experts) + # expert_scores[i].shape = (tokens, num_experts) + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} token_per_experts ' 
+ f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} expert_scores ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) + + tokens = expert_scores[0].shape[0] + assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = (args.moe_num_experts * args.moe_loss_weight) + scale_denominator = (args.num_layers * tokens * args.moe_top_k) + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +# NOTE: This class defines MoE expert computation, including expert model parallel +# communication. When using FSDP on top of MegaBlocks this is the module that should +# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model +# parallel all2all. +class ParallelMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super(ParallelMLP, self).__init__() + self.args = args + + # Calculate the number of experts in total and the number of experts + # owned by this rank. + # world_size = mpu.get_expert_parallel_world_size(args) + self.num_experts = args.moe_num_experts + self.top_k = self.args.moe_top_k + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + # Expert MLP. + self.mlp = mlp.MLP(args) + + self.bias: Optional[torch.Tensor] + if self.args.bias: + # Note that the output bias is not parallelized with expert + # model parallelism. + self.bias = torch.nn.Parameter( + torch.empty( + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + # Select the forward function for the operating mode. 
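+        #
+        # parallel_forward_once adds the expert-parallel all-to-all exchange;
+        # forward_once keeps every expert local to this rank.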
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) + + def expert_capacity(self, tokens: int) -> int: + world_size = mpu.get_expert_parallel_world_size(self.args) + tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) + return int(self.args.moe_capacity_factor * tokens_per_expert) + + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): + """Calculate the load balancing loss contribution.""" + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == self.num_experts + assert len(tokens_per_expert.size()) == 1 + num_experts, = tokens_per_expert.size() + assert num_experts == self.num_experts + scale = self.num_experts / (tokens * self.top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # + # TODO(tgale): Is it worth doing this conversion to 32-bit + # prior? Could we place the `torch.max` operation to return + # 32-bit expert indices? + top_expert = top_expert.int() + output = ops.sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + # + # TODO(tgale): Does the sorted data produce a more favorable + # data distribution for histogram? Or is the op parallelism + # worth more? + tokens_per_expert = ops.histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None + bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + + return indices, bin_ids, bins, tokens_per_expert + + def permute_and_compute( + self, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, + ): + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output + + # Perform the expert computation. Note that we don't + # use biases for these linear operations. + x = self.mlp(x) + + # Un-route the data for the MoE output. + return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. 
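+            #
+            # expert_capacity() scales top_k * tokens * world_size /
+            # num_experts by moe_capacity_factor, so a factor of zero selects
+            # the dropless path where capacity follows the busiest expert.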
+ sl, bs, _ = x.size() + expert_capacity = self.expert_capacity(sl * bs) + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = self.permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + self.top_k, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignement. After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + # + # Compute the mapping of local tokens to experts. + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so every device gets the counts. + repeated_tokens_per_expert = ops.repeat( + tokens_per_expert, + (mpu.hidden_sharding_degree(self.args),), + ) + + # Pass token count information to the device on which the + # target expert resides. + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.args.expert_parallel_group, + async_op=True, + ) + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + # + # This view updates the shape of the tensor from [sl, bs, hs] to + # [sl * bs, hs] prior to the permutation. + x = x.view(-1, x.shape[-1]) + output = ops.gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + experts_per_rank = mpu.experts_per_rank(self.args) + + # Reshape to [world_size, num_experts_per_rank]. + world_size = mpu.get_expert_parallel_world_size(self.args) + repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) + + # TODO(tgale): It might be faster to do this on the GPU and + # then communicate the results back to the host. + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. 
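+            #
+            # The all_to_all below takes Python int split sizes on the host,
+            # which is why the counts are moved to the CPU here.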
+ send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so devices that share experts + # get all of the tokens assigned to them. + # + # TODO(tgale): Fuse this into the prior, local permutation. + x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, + recv_counts, + send_counts, + self.args.expert_parallel_group, + async_op=True, + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + replicate_bins = ops.inclusive_cumsum( + parallel_tokens_per_expert.flatten(), + 0, + ) + replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * mpu.hidden_sharding_degree(self.args), + dtype=torch.int32, + device=indices.device, + ), + mpu.experts_per_rank(self.args), + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, + tokens_received, + ).flatten() + + # TODO(tgale): The sort_end_bit here can be reduced. + parallel_bin_ids, parallel_indices = ops.sort( + parallel_top_expert, + self.sort_end_bit, + ) + + # Calculate the bins boundaries from the token counts. + parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, + dtype=torch.int, + ) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + tokens, _ = x.size() + expert_capacity = self.expert_capacity(tokens) + if expert_capacity == 0: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + if self.args.mlp_impl == 'grouped': + # GroupedMLP requires counts on CPU. We can use the tensor already + # moved to CPU for the prior all_to_all, which avoids an extra + # device synchronization. + parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( + dim=0, + dtype=torch.int, + ) + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all( + parallel_x, + send_counts, + recv_counts, + self.args.expert_parallel_group, + ) + + # Reduce along the hidden sharding to get the final outputs. + # + # TODO(tgale): Fuse this into the following local permutation. 
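+        #
+        # With hidden sharding each token comes back hidden_sharding_degree
+        # times (once per shard of its experts), so the replicas are summed
+        # before the final local un-permutation.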
+ shape = ( + mpu.hidden_sharding_degree(self.args), + -1, + self.args.hidden_size, + ) + x = ops.sum(x.view(shape), dim=0) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + return x, tokens_per_expert.flatten() + + def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + in_shape = x.size() + + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) + if self.training and self.args.moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, scores)) + x = x.view(in_shape) + if self.bias is not None: + if self.args.return_bias: + return x, self.bias + return x + self.bias + return x + + +class MoE(torch.nn.Module): + + def __init__(self, args: Arguments): + super(MoE, self).__init__() + + # Token router. + self.router = router.LearnedRouter(args) + + # Expert computation helper. + self.experts = self._init_experts_mlp(args) + + self.shared_expert = None + if args.shared_expert: + # SharedExpert computation helper. + self.shared_expert = sharedexpert_registry.get(args) + + def _init_experts_mlp(self, args: Arguments): + return ParallelMLP(args) + + def forward(self, x: torch.Tensor): + # NOTE: If we're going to cast the activations to lower precision + # do it before we permute the tokens to save bandwidth. + x = common.cast_if_autocast_enabled(x) + + # Compute the expert scores and assignments. + scores, expert_weights, top_experts = self.router(x) + + # Compute the experts. + out = self.experts(x, scores, expert_weights, top_experts) + if self.shared_expert is not None: + shared_expert_out = self.shared_expert(x) + out = self.shared_expert.add_experts_sharedexpert( + shared_expert_out, + out, + ) + return out diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mpu.py new file mode 100644 index 0000000000000000000000000000000000000000..b23213902f4567d7fdb0158cbcf5406c2b2aa601 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mpu.py @@ -0,0 +1,93 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.distributed as dist + +from megablocks.layers.arguments import Arguments + + +class MoeParam(torch.Tensor): + + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool + + +def is_moe_param(tensor: torch.Tensor) -> bool: + return hasattr(tensor, 'expert_model_parallel') + + +def get_expert_parallel_world_size(args: Arguments) -> int: + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) + + +def get_expert_parallel_rank(args: Arguments) -> int: + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) + + +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, 'expert_model_parallel') + setattr(tensor, 'expert_model_parallel', is_parallel) + + +def param_is_expert_model_parallel(param: MoeParam) -> bool: + return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) + + +def copy_expert_model_parallel_attributes( + destination_tensor: torch.Tensor, + source_tensor: torch.Tensor, +): + if hasattr(source_tensor, 'expert_model_parallel'): + setattr( + destination_tensor, + 'expert_model_parallel', + getattr(source_tensor, 
'expert_model_parallel'), + ) + + +def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + for i in range(world_size): + dist.barrier(group) + if i == rank: + print(f'rank = {rank}', *x) + + +# Helpers for expert/tensor sharding. +def expert_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = min(world_size, args.moe_num_experts) + + if (args.moe_num_experts % esd) != 0: + raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) + return esd + + +def hidden_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = expert_sharding_degree(args) + hsd = world_size // esd + + if (args.ffn_hidden_size % hsd) != 0: + raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", + ) + return hsd + + +def experts_per_rank(args: Arguments) -> int: + return args.moe_num_experts // expert_sharding_degree(args) + + +def features_per_rank(args: Arguments) -> int: + return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/router.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9dcd9e433322482a80d5b95afccee5c12368f8 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/router.py @@ -0,0 +1,114 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch + +from megablocks.layers import common +from megablocks.layers.arguments import Arguments + +_ROUTER_LOGITS = [] + + +def _save_router_logits(logits: torch.Tensor, args: Arguments): + if args.moe_zloss_weight == 0: + return + global _ROUTER_LOGITS + _ROUTER_LOGITS.append(logits) + + +def clear_router_zloss(): + global _ROUTER_LOGITS + _ROUTER_LOGITS.clear() + + +def batched_router_zloss(args: Arguments): + global _ROUTER_LOGITS + + if args.moe_zloss_weight == 0: + import warnings + warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') + return 0 + + logits_per_router = _ROUTER_LOGITS + + if args.moe_zloss_in_fp32: + logits_per_router = [logits.float() for logits in logits_per_router] + + unscaled_zloss_per_router = torch.stack([ + torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router + ]) + + return args.moe_zloss_weight * unscaled_zloss_per_router + + +# NOTE: To enable end-to-end benchmarking without convergence we +# support a flag to force the router to assign tokens uniformly +# across the experts. We do this with a custom autograd operation +# so that PyTorch still executes the full set of router operation. +class _UniformExpertAssignment(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, num_experts: int): + out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) + out = torch.remainder(out, num_experts) + return out.view(x.shape) + + +_uniform_expert_assignment = _UniformExpertAssignment.apply + + +class LearnedRouter(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + + # Learned router parameters. + # + # NOTE: This weight matrix is not parallelized with expert model + # parallelism. 
Each device needs the entire router weight matrix + # so that it can route its batch of data correctly. + self.layer = torch.nn.Linear( + args.hidden_size, + args.moe_num_experts, + bias=False, + dtype=common.dtype(args), + device=args.device, + ) + args.init_method(self.layer.weight) + + def jitter(self, x: torch.Tensor): + low: float = 1.0 - self.args.moe_jitter_eps + high: float = 1.0 + self.args.moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return low + noise * (high - low) + + def _top_k(self, scores: torch.Tensor): + if self.args.moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, self.args.moe_top_k, dim=-1) + + def forward(self, x: torch.Tensor): + if self.training and self.args.moe_jitter_eps is not None: + x = x * self.jitter(x) + + logits = self.layer(x.view(-1, x.shape[-1])) + _save_router_logits(logits, self.args) + scores = logits.softmax(dim=-1) + expert_weights, expert_indices = self._top_k(scores) + if self.args.moe_normalize_expert_weights: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + + expert_indices = ( + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) if self.args.uniform_expert_assignment else expert_indices + ) + return scores, expert_weights, expert_indices diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..0f62db39a2dda5642f6baa9d9b949f7c31cf6d35 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -0,0 +1,30 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +_REGISTRY = { + 'mlp': mlp.SharedMLP, + 'glu': glu.SharedGLU, +} + + +def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: + """Returns an SharedMLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + SharedMLP instance. + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated SharedMLP constructed using the input args. 
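+
+    Raises:
+        ValueError: If args.mlp_type is not one of the registered types.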
+ """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + return _REGISTRY[args.mlp_type](args) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from megablocks.ops.binned_gather import binned_gather +from megablocks.ops.binned_scatter import binned_scatter +from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum +from megablocks.ops.gather import gather +from megablocks.ops.histogram import histogram +from megablocks.ops.padded_gather import padded_gather +from megablocks.ops.padded_scatter import padded_scatter +from megablocks.ops.repeat import repeat +from megablocks.ops.replicate import replicate +from megablocks.ops.round_up import round_up +from megablocks.ops.scatter import scatter +from megablocks.ops.sort import sort +from megablocks.ops.sum import sum +from megablocks.ops.topology import topology + +__all__ = [ + 'binned_gather', + 'binned_scatter', + 'exclusive_cumsum', + 'inclusive_cumsum', + 'gather', + 'histogram', + 'padded_gather', + 'padded_scatter', + 'repeat', + 'replicate', + 'round_up', + 'scatter', + 'sort', + 'sum', + 'topology', +] diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..47b95301dbb35b0c3154d1ff2878d20792ce5cb2 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -0,0 +1,60 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + +from megablocks import benchmark_util +from megablocks.layers.all_to_all import all_to_all + +_ALL_TO_ALL_BENCHMARK = ( + (8, 1024), + (16, 1024), + (32, 1024), + (64, 1024), + (128, 1024), + (256, 1024), + (512, 1024), + (1024, 1024), + (2 * 1024, 1024), + (4 * 1024, 1024), + (8 * 1024, 1024), + (16 * 1024, 1024), + (32 * 1024, 1024), + (64 * 1024, 1024), + (128 * 1024, 1024), + (256 * 1024, 1024), + (512 * 1024, 1024), + (1024 * 1024, 1024), +) + + +def benchmark_all_to_all(group, sl, hs): + world_size = dist.get_world_size(group) + assert (sl % world_size) == 0 + send_recv_sizes = [sl // world_size] * world_size + + x = torch.randn((sl, hs)).cuda().half() + + details = { + 'world_size': world_size, + 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. 
+ } + + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) + + time, std = benchmark_util.benchmark_function(benchmark) + + if dist.get_rank(group) == 0: + benchmark_util.log_benchmark('All-To-All', details, time, std) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _ALL_TO_ALL_BENCHMARK: + benchmark_all_to_all(group, *args) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..89cce1b627137d7f987baca62fd6d82c5c04659a --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_gather.py @@ -0,0 +1,37 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_gather kernel. +class BinnedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): + ctx.save_for_backward(indices, bins) + ctx.top_k = top_k + return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + indices, bins = ctx.saved_tensors + out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) + return out, None, None, None, None + + +binned_gather = BinnedGatherOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ce0d6f5a890a58b89e6a501216a9822341323f --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_scatter.py @@ -0,0 +1,59 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_scatter kernel. +class BinnedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + assert len(x.size()) == 3 + ctx.bin_size = x.size(1) + ctx.top_k = top_k + + # TODO(tgale): Don't save 'x' for backwards if we don't need to + # calculate the gradient w.r.t. 'weights'. 
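+        #
+        # 'x' is only consumed in backward when ctx.needs_input_grad[2] is
+        # set, i.e. when a gradient for 'weights' is required.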
+ ctx.save_for_backward(x, indices, weights, bins) + return kernels.binned_scatter(x, indices, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + x, indices, weights, bins = ctx.saved_tensors + out = kernels.binned_gather( + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[2]: + wgrad = kernels.binned_scatter_wgrad( + x, + grad, + indices, + bins, + ctx.top_k, + ) + return out, None, wgrad, None, None + + +binned_scatter = BinnedScatterOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/cumsum.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7c5137c900d66a959a76fa681436362a4df906 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/cumsum.py @@ -0,0 +1,52 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + # import megablocks_ops as ops # type: ignore + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrappers for cumsum kernels. +# NOTE: Does not support gradients. +class ExclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int): + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.exclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.exclusive_cumsum(x, dim, out) + return out + + +exclusive_cumsum = ExclusiveCumsumOp.apply + + +class InclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.inclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.inclusive_cumsum(x, dim, out) + return out + + +inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/gather.py new file mode 100644 index 0000000000000000000000000000000000000000..41b09a1233e8a996ff1062579cd0810d095ad1e6 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/gather.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for gather kernel. 
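+#
+# The forward pass permutes tokens into expert order using the
+# indices/bin_ids/bins metadata; the backward pass applies the inverse
+# permutation to the gradient via kernels.scatter.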
+class GatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins) + ctx.top_k = top_k + return kernels.gather(x, indices, bin_ids, None, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins = ctx.saved_tensors + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) + return out, None, None, None, None, None + + +gather = GatherOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram.py new file mode 100644 index 0000000000000000000000000000000000000000..77e06a29163aa335dc61e8c50932d5f83322ae1d --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram.py @@ -0,0 +1,27 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for histogram kernel. +# NOTE: Does not support gradients. +class HistogramOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, max_val: float): + return ops.histogram(x, max_val) + + +histogram = HistogramOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..9de8e65281c829ba63b5c60853dfa5bb0a333988 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -0,0 +1,78 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_HISTOGRAM_TESTS = ( + (16384, torch.int32, 2), + (16384, torch.int32, 4), + (16384, torch.int32, 8), + (16384, torch.int32, 16), + (16384, torch.int32, 32), + (16384, torch.int32, 64), + (16384, torch.int32, 128), + (16384, torch.int32, 256), +) + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. 
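+    #
+    # The warm-up call absorbs one-time costs (kernel compilation, CUDA
+    # context setup) so they do not skew the timed iterations below.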
+ fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class HistogramBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testHistogram(self, n, dtype, max_val): + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testTorchHistogram(self, n, dtype, max_val): + x = torch.randint(0, 128, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..bfa7b7c2f697f9454c5381e31ee78040be5b229e --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -0,0 +1,403 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import stk +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + + +# Calling tensor.t() calls tensor.transpose(0, 1) which calls +# torch.as_strided(...). Circumvent this chain to avoid an overhead +# this adds. +def transpose_view(x): + return torch.as_strided( + x, + (x.shape[1], x.shape[0]), + (x.stride()[1], x.stride()[0]), + ) + + +_MATMUL_TESTS = ( + (64 * 1024, 512, 2048, 64), + (32 * 1024, 768, 3072, 64), + (8 * 1024, 1024, 4096, 64), + (4 * 2048, 4096, 4 * 4096, 4), +) + + +def log_benchmark(name, arguments, time, std, flops): + benchmark_util.log_benchmark(name, arguments, time, std) + print('flops = {:.2f}B'.format(flops / 1e9)) + print('throughput = {:.2f}T'.format(flops / 1e9 / time)) + print('=' * 60) + + +class MatmulBenchmark(parameterized.TestCase): + + def build_sparse_matrix(self, x, padded_bins, fhs, ne): + blocking = 128 + padded_tokens, _ = x.size() + assert padded_tokens % blocking == 0 + assert fhs % blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // blocking + blocks_per_row = fhs // blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. 
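+        #
+        # ops.topology expands the padded bin boundaries into the block
+        # column index of every nonzero block; one blocking x blocking fp16
+        # block of data is then allocated per nonzero.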
+ column_indices = ops.topology( + padded_bins, + blocking, + block_rows, + blocks_per_row, + ) + data = torch.empty( + column_indices.numel(), + blocking, + blocking, + dtype=torch.float16, + device=x.device, + ) + shape = (padded_tokens, fhs * ne) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix(shape, data, row_indices, column_indices, offsets) + + def build_input_matrix(self, sl, hs, ne): + x = torch.randn((sl, hs)).cuda().half() + + # Assign tokens to experts uniformly. + top_expert = torch.arange(0, sl).cuda().int() % ne + + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) + return out, padded_bins + + def build_weight_matrix(self, ne, hs, fhs): + return torch.randn((hs, ne * fhs)).cuda().half() + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(x, w, topo) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd::SDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(topo, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradX::DSD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + topo = topo.t() + + def benchmark(): + return stk.ops.dsd(topo, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DSD::NN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): 
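+        # Input gradient of the second FFN linear: dX = dY @ W^T, timed as an
+        # SDD product whose output reuses the block-sparse topology of X.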
+ x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(out, w, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::SDD::NT', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + x = x.t() + + def benchmark(): + return stk.ops.dsd(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + + w = w.transpose(1, 2).contiguous() + w = w.transpose(1, 2) + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd:DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + w = w.transpose(1, 2).contiguous() + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradX:DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + out = out.transpose(1, 2) + + def benchmark(): + return torch.bmm(out, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradW:DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) 
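+    # The dense (DDD) variants time plain torch.bmm with tokens split evenly
+    # across experts, as a baseline for the block-sparse SDD/DSD paths above;
+    # the flops argument counts two FLOPs per accumulated product, so the
+    # logged throughput is reported in TFLOP/s.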
+ + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + w = torch.transpose(w, 1, 2) + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + x = torch.transpose(x, 1, 2) + + def benchmark(): + return torch.bmm(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..f272a7768dc6468e81bd2cd25294ca16a6826c08 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_gather.py @@ -0,0 +1,55 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_gather kernel. +class PaddedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins, padded_bins) + ctx.top_k = top_k + return kernels.padded_gather( + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins, padded_bins = ctx.saved_tensors + out = kernels.padded_scatter( + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) + return out, None, None, None, None, None + + +padded_gather = PaddedGatherOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff81dd9456a04258346266e9e85871bda56c65b --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter.py @@ -0,0 +1,98 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_scatter kernel. 
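+# Forward un-permutes rows from padded expert order back to token order,
+# applying the optional top_k routing weights and reducing over top_k.
+# Backward mirrors this: dgrad re-gathers the incoming gradient into padded
+# expert order, and wgrad (only when the weights need grad) reduces x * grad
+# per token/expert assignment.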
+class PaddedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward( + indices, + bin_ids, + weights, + bins, + padded_bins, + *maybe_x, + ) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.padded_gather( + grad, + indices, + bin_ids, + weights, + bins, + padded_bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.padded_scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + padded_bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None, None + + +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): + return PaddedScatterOp.apply( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..81dde4e4b4c0f550c6027390e8de285fa4d842f7 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -0,0 +1,66 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PADDED_SCATTER_BENCHMARK = ( + # dMoE-Medium, 8-way EMP. + (1024 * 16, 1024, 8, 4), + # dMoE-Medium, post-all-to-all. + (1024 * 16 * 4, 1024, 8, 1), +) + + +class PaddedScatterTest(parameterized.TestCase): + + @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) + def testPaddedScatter(self, sl, hs, ne, top_k): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + # Sample weights for the scatter reduce. + weights = torch.rand((sl * top_k,)).cuda().half() + + # Gather the data to prepare for backwards. 
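+        # padded_gather first permutes the tokens into expert order (each
+        # expert's rows padded to a multiple of 128), so the benchmarked
+        # padded_scatter below measures exactly the inverse permutation
+        # performed after expert computation.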
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + def benchmark(): + return ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + time, std = benchmark_util.benchmark_function(benchmark) + benchmark_util.log_benchmark( + 'Padded Scatter', + { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + 'top_k': top_k, + }, + time, + std, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..837f07e2858d5c65c14e3f262617e7d2985ce901 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -0,0 +1,149 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PERMUTE_TESTS = ( + (16384, 768, 2), + (16384, 768, 4), + (16384, 768, 8), + (16384, 768, 16), + (16384, 768, 32), + (16384, 768, 64), + (16384, 768, 128), + (16384 * 8, 768, 2), + (16384 * 8, 768, 4), + (16384 * 8, 768, 8), + (16384 * 8, 768, 16), + (16384 * 8, 768, 32), + (16384 * 8, 768, 64), + (16384 * 8, 768, 128), +) + + +class PermuteBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedGather(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.binned_gather(x, indices, bins, ec) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedScatter(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.binned_gather(x, indices, bins, ec) + + def benchmark(): + return ops.binned_scatter(x, indices, bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedGather(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. 
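+        # Standard routing preamble: radix-sort the expert ids to group tokens
+        # by expert, histogram to count tokens per expert, round each count up
+        # to the 128-row block size, and take inclusive cumsums to form the bin
+        # and padded-bin boundaries consumed by padded_gather.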
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedScatter(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testCopy(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + # ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + y = x.clone() + + def benchmark(): + return y.copy_(x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/repeat.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/repeat.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def repeat(x: torch.Tensor, tiling: torch.Size): + if all((t == 1 for t in tiling)): + return x + return x.repeat(*tiling) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..e570e22093339d1b71dc830edf656897c67f40e7 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/replicate.py @@ -0,0 +1,36 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. 
+try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for replicate kernel. +class ReplicateOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): + ctx.save_for_backward(bins) + out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) + ops.replicate_forward(x, bins, out) + return out + + @staticmethod + def backward(ctx: Any, grad: torch.Tensor): + bins, = ctx.saved_tensors + out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) + ops.replicate_backward(grad, bins, out) + return out, None, None + + +replicate = ReplicateOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/round_up.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/round_up.py new file mode 100644 index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/round_up.py @@ -0,0 +1,14 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def round_up(x: torch.Tensor, value: int): + assert isinstance(value, int) + assert x.dtype == torch.int32 + + # TODO(tgale): If this becomes and issue + # do this in a custom kernel. We only expect + # to use this on arrays of less than 1k elements. + return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..a5aaafc402a651a47595c25dbf71e382903e5022 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/scatter.py @@ -0,0 +1,72 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for scatter kernel. 
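+# Unpadded counterpart of PaddedScatterOp: the same kernels are reused with
+# `bins` passed in place of `padded_bins`, so rows are written back in token
+# order without per-expert padding.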
+class ScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins = saved_tensors[:4] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.gather( + grad, + indices, + bin_ids, + weights, + bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None + + +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +) -> Optional[torch.Tensor]: + return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort.py new file mode 100644 index 0000000000000000000000000000000000000000..a64a4553323d8de71a1d5f1821bac50249c68197 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional, Tuple + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + +_BITS_FOR_DTYPE = { + torch.int16: 16, + torch.int32: 32, + torch.int64: 64, +} + + +# Autograd wrapper for sort kernel. +# NOTE: Does not support gradients. 
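+# Returns (sorted values, argsort indices). end_bit defaults to the full bit
+# width of the dtype (via _BITS_FOR_DTYPE) and can be lowered when the values
+# are small (e.g. expert ids) to reduce radix-sort passes.
+# Typical use elsewhere in this package: bin_ids, indices = ops.sort(top_expert).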
+class SortOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + if end_bit is None: + end_bit = _BITS_FOR_DTYPE[x.dtype] + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + ops.sort(x, end_bit, x_out, iota_out) + return (x_out, iota_out) + + +sort = SortOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -0,0 +1,85 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_SORT_TESTS = ( + (16384, torch.int32, None), + (16384, torch.int32, 2), + (16384, torch.int32, 128), +) + +_BASELINE_SORT_TESTS = ((16384,),) + + +def numpy_dtype(dtype): + types = { + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + } + return types[dtype] + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. + fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class SortBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_SORT_TESTS) + def testSort(self, n, dtype, max_val): + if max_val is None: + max_val = np.iinfo(numpy_dtype(dtype)).max + end_bit = int(np.ceil(np.log2(max_val))) + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_BASELINE_SORT_TESTS) + def testTorchSort(self, n): + x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) + arguments = { + 'n': n, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sum.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sum.py new file mode 100644 index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sum.py @@ -0,0 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import torch + + +def sum(x: torch.Tensor, dim: int = 0): + if x.shape[dim] == 1: + return x.squeeze(dim=dim) + return x.sum(dim=dim) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/topology.py new file mode 100644 index 
0000000000000000000000000000000000000000..b4a693a82a36c052e51afb4a02077e7849a4abf1 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/topology.py @@ -0,0 +1,45 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for topology kernel. +# NOTE: Does not support gradients. +class TopologyOp(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, + ): + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + ops.indices( + padded_bins, + block_size, + output_block_rows, + output_block_columns, + out, + ) + return out + + +topology = TopologyOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d8a2495c4baab89e5f90d84fa92f406a29ea10 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/__init__.py @@ -0,0 +1,195 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from ._ops import ops + +from megablocks.layers.arguments import Arguments +from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE +from megablocks.layers.glu import SparseGLU +from megablocks.layers.mlp import MLP, SparseMLP +from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +# This section contains the direct kernel exports (not inlcuded in the original code) +def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute exclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.exclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute inclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.inclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: + """ + Compute histogram of input tensor values. + + Args: + x: Input tensor + num_bins: Number of histogram bins + + Returns: + Histogram tensor with counts for each bin + """ + return ops.histogram(x, num_bins) + + +def indices( + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, +) -> torch.Tensor: + """ + Construct indices from padded bins for sparse operations. 
+ + Args: + padded_bins: Tensor containing bin boundaries + block_size: Size of each block + output_block_rows: Number of rows in output blocks + output_block_columns: Number of columns in output blocks + + Returns: + Tensor containing constructed indices + """ + return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) + + +def replicate_forward( + x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Forward pass of replicate operation - replicate values according to bin sizes. + + Args: + x: Input tensor with values to replicate + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_forward(x, bins, out) + + +def replicate_backward( + grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Backward pass of replicate operation - reduce gradients back to bins. + + Args: + grad: Gradient tensor to reduce + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_backward(grad, bins, out) + + +def sort( + x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor +) -> torch.Tensor: + """ + Radix sort with index tracking. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + x_out: Output tensor for sorted values + iota_out: Output tensor for sorted indices + + Returns: + The sorted values tensor + """ + return ops.sort(x, end_bit, x_out, iota_out) + + +# Convenience functions for common use cases +def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: + """ + Compute cumulative sum with automatic output allocation. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum (default: last dimension) + exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum + + Returns: + New tensor containing the cumulative sum + """ + out = torch.empty_like(x) + if exclusive: + return exclusive_cumsum(x, dim, out) + else: + return inclusive_cumsum(x, dim, out) + + +def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sort tensor and return both sorted values and indices. 
+ + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + + Returns: + Tuple of (sorted_values, sorted_indices) + """ + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + sort(x, end_bit, x_out, iota_out) + return x_out, iota_out + + +# Export public API +__all__ = [ + # Direct kernel exports + "exclusive_cumsum", + "inclusive_cumsum", + "histogram", + "indices", + "replicate_forward", + "replicate_backward", + "sort", + "cumsum", + "argsort", + # Original exports + "Arguments", + "ParallelDroplessMLP", + "dMoE", + "SparseGLU", + "MLP", + "SparseMLP", + "MoE", + "ParallelMLP", + "get_load_balancing_loss", +] diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..abd969561ee41282ee024c3a0326ab521b4e7a00 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec45bdb77d89916e5c58e17ddb76148799e3cac07fa6f4e93bf9140d9b2039bb +size 11788400 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd42432c910375fb1713b3bf29145e48b556498 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _megablocks_359242d +ops = torch.ops._megablocks_359242d + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_megablocks_359242d::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_version.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_version.py @@ -0,0 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +"""The MegaBlocks Version.""" + +__version__ = '0.11.0.dev0' diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/backend/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/backend/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/backend/kernels.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/backend/kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..b584ceede926ca30abef2dec581cb3ff329e8e16 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/backend/kernels.py @@ -0,0 +1,543 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') + + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') + + +def assert_equal(a, b): + if a != b: + raise ValueError(f'Expected 
dimensions to be equal but got {a} and {b}.',) + + +# a: (tokens, hidden_size), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy( + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Our index into array 'a'. + index_a = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'b'. + index_b = offset_in_bin + if bin_idx > 0: + index_b += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: Because of the padding, the output size is dynamic. + # We load the final padded bin bound to get the output rows. + output_rows = padded_bins[-1].cpu().item() + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def gather(x, indices, bin_ids, weights, bins, top_k): + # Validate the input shapes. 
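+    # (Unlike padded_gather there is no padding here: `bins` is passed for both
+    # bin arguments of the copy kernel and the output has exactly
+    # tokens * top_k rows.)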
+ assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: There is no padding so the output rows equals the + # input rows multiplied by top_k. + output_rows = x.shape[0] * top_k + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + out, + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + + +def scatter(x, indices, bin_ids, weights, bins, top_k): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +# x: (tokens, top_k, hidden_size), real +# grad: (tokens, hidden_size), real. +# wgrad: (tokens, top_k), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy_wgrad( + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Our index into 'tokens * top_k'. + index_out = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'x'. + index_x = offset_in_bin + if bin_idx > 0: + index_x += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. 
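+    # Each program instance produces one routing-weight gradient: the dot
+    # product of an expert-output row of `x` with the incoming gradient row for
+    # its token, accumulated in fp32 in BLOCK_X-sized chunks.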
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + _padded_copy_wgrad[(indices.shape[0],)]( + x, + grad, + out, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + TOP_K=top_k, + ) + return out + + +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy( + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_b = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_a = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + # + # NOTE: We need to zero the output in both directions. 
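+    # (Hence binned_gather/binned_scatter allocate their outputs with
+    # torch.zeros rather than torch.empty: capacity slots beyond the tokens
+    # actually routed to an expert are never written by this kernel.)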
+ iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + num_experts = bins.shape[0] + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + + _binned_copy[(num_experts, expert_capacity)]( + x, + out, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def binned_scatter(x, indices, weights, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( + out, + x, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=hidden_size, + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy_wgrad( + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_x = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_out = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. 
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def binned_scatter_wgrad(x, grad, indices, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) + _binned_copy_wgrad[(num_experts, expert_capacity)]( + x, + grad, + out, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS=hidden_size, + TOP_K=top_k, + ) + return out diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/bak.__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/bak.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5217959caf74527e3bf7f80db6f93be21c016963 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/bak.__init__.py @@ -0,0 +1,23 @@ +from megablocks_moe.megablocks import ( + MoE, + dMoE, + get_load_balancing_loss, + ParallelMLP, + ParallelDroplessMLP, + SparseMLP, + MLP, + SparseGLU, + Arguments, +) + +__all__ = [ + "MoE", + "dMoE", + "get_load_balancing_loss", + "ParallelMLP", + "ParallelDroplessMLP", + "SparseMLP", + "MLP", + "SparseGLU", + "Arguments", +] diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/benchmark_util.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/benchmark_util.py new file mode 100644 index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/benchmark_util.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + + +def log_benchmark(name, arguments, time, std): + print('=' * 60) + print(f'{name} Benchmark') + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) + print('=' * 60) + + +def benchmark_function(fn, iterations=100, warmup=10): + # Warmup iterations. 
+ for _ in range(warmup): + fn() + + times = [] + for i in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + fn() + end.record() + + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + return np.mean(times), np.std(times) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6d3f977f360fd0ad5800c3b5da9ce57be794b9b8 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm_util.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import warnings + +_grouped_gemm_is_available: bool = False +try: + import grouped_gemm + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn('Grouped GEMM not available.') + + +def grouped_gemm_is_available(): + return _grouped_gemm_is_available + + +def assert_grouped_gemm_is_available(): + msg = ( + 'Grouped GEMM not available. Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg + + +backend = grouped_gemm.backend if grouped_gemm_is_available() else None +ops = grouped_gemm.ops if grouped_gemm_is_available() else None diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..849b023b1a0f765fb1e26addf4670dc6db785a52 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# from megablocks.layers.dmoe import dMoE +from megablocks.layers.moe import MoE + +__all__ = [ + 'MoE', + # 'dMoE', +] diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/activation_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..a31770ba179fed06abf2da10102ccaeed1d3ee4e --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/activation_fn.py @@ -0,0 +1,33 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Union + +import torch +from stk import Matrix + + +def act_fn( + x: Matrix, + function: Callable, + return_grad_fn: bool = False, + **kwargs, +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) + with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): + if return_grad_fn: + x.data.requires_grad = True + out = function(x.data, **kwargs) + y = Matrix( + x.size(), + out, + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + if return_grad_fn: + return y, out.backward + return y diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/all_to_all.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/all_to_all.py @@ -0,0 +1,54 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + + +class 
AllToAllOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) + + ctx.input_shape = x.shape + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + ctx.group = group + handle = dist.all_to_all_single( + out, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + return out, handle + + @staticmethod + def backward(ctx, grad, _): + if ctx.needs_input_grad[0]: + out = torch.empty( + ctx.input_shape, + device=grad.device, + dtype=grad.dtype, + ) + dist.all_to_all_single( + out, + grad, + output_split_sizes=ctx.input_split_sizes, + input_split_sizes=ctx.output_split_sizes, + group=ctx.group, + ) + return out, None, None, None, None + return None, None, None, None, None + + +def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): + return AllToAllOp.apply( + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..3962c771c90012535aab058443f89c541a2e9236 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/arguments.py @@ -0,0 +1,100 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from functools import partial +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +import megablocks.grouped_gemm_util as grouped_gemm + +# Type annotation for in-place Tensor initialization function. +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] + +_ALLOWED_BITWIDTHS = (-1, 4, 8) + +DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') + + +@dataclasses.dataclass +class Arguments: + # Model arguments. + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + num_layers: int = 1 + bias: bool = True + return_bias: bool = True + activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN + + # MoE arguments. + moe_num_experts: int = 1 + moe_top_k: int = 1 + moe_capacity_factor: int = 1 + moe_normalize_expert_weights: Optional[Union[int, float]] = None + moe_loss_weight: float = 0.1 + moe_jitter_eps: Optional[float] = None + moe_lbl_in_fp32: bool = False + + # Parallelism arguments. + moe_expert_model_parallelism: bool = False + expert_parallel_group: Optional[dist.ProcessGroup] = None + pipeline_model_parallel_size: int = 1 + num_layers_per_virtual_pipeline_stage: Optional[int] = None + + # Compute arguments. + memory_optimized_mlp: bool = False + mlp_type: str = 'mlp' + mlp_impl: str = 'sparse' + + # Initialization arguments. + fp16: bool = True + bf16: bool = False + device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) + init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) + output_layer_init_method: InitFn = init_method + + # Benchmarking arguments. 
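+    # (When enabled, the learned top-k routing is presumably replaced by a
+    # fixed uniform token-to-expert assignment, handled in the router elsewhere
+    # in the library, to keep kernel benchmarks load-balanced and repeatable.)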
+ uniform_expert_assignment: bool = False + + # shared expert arguments + shared_expert: bool = False # enable using shared expert + fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) + fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers + remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored + shared_expert_hidden_size: Optional[ + int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size + shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) + + # Router Z-loss arguments + moe_zloss_weight: float = 0 # 1e-3 is a reasonable value + moe_zloss_in_fp32: bool = False + + def __post_init__(self): + # Sparse MLP is not supported with triton >=3.2.0 + # TODO: Remove this once sparse is supported with triton >=3.2.0 + if self.__getattribute__('mlp_impl') == 'sparse': + try: + import triton + if triton.__version__ >= '3.2.0': + raise ValueError( + 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.', + ) + except ImportError: + raise ImportError('Triton is required for sparse MLP implementation') + + if self.__getattribute__('mlp_impl') == 'grouped': + grouped_gemm.assert_grouped_gemm_is_available() + + if self.shared_expert_hidden_size is None: + self.shared_expert_hidden_size = self.ffn_hidden_size + + +def from_megatron(megatron_args: Any): + args = Arguments() + for field in dataclasses.fields(args): + if hasattr(megatron_args, field.name): + setattr(args, field.name, getattr(megatron_args, field.name)) + return args diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/common.py new file mode 100644 index 0000000000000000000000000000000000000000..ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/common.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from megablocks.layers.arguments import Arguments + + +def dtype(args: Arguments): + if args.fp16: + return torch.float16 + elif args.bf16: + return torch.bfloat16 + return None + + +def cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == 'cuda': + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == 'cpu': + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d765bd04387a29bddd2789fae04821635b555e82 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -0,0 +1,42 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +MlpType = Union[mlp.SparseMLP, glu.SparseGLU] + +_REGISTRY = { + 'mlp': { + 'grouped': mlp.GroupedMLP, + 'sparse': mlp.SparseMLP, + }, + 'glu': { + 'grouped': glu.GroupedGLU, + 'sparse': glu.SparseGLU, + }, +} + + +def get(args: Arguments) -> MlpType: + 
"""Returns an MLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + MLP instance. This only contains MLPs for use in dMoEs + (ie. only for the dropless versions of MoEs). + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated MLP constructed using the input args. + """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + if args.mlp_impl not in _REGISTRY[args.mlp_type]: + raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) + + return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..205727ff4d63f9e8dc9648acaac99a97f3394d6f --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmoe.py @@ -0,0 +1,327 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import stk.ops +import torch +from stk import Matrix + +import megablocks.ops as ops +# from megablocks.ops import ops +from megablocks.layers import common, dmlp_registry, moe, mpu +from megablocks.layers.arguments import Arguments + + +def promote_scalar(x): + return x.view(1) if not len(x.size()) else x + + +class ParallelDroplessMLP(moe.ParallelMLP): + + def __init__(self, args: Arguments): + super(ParallelDroplessMLP, self).__init__(args) + self.hidden_size = args.hidden_size + self.ffn_hidden_size = mpu.features_per_rank(args) + self.blocking = 128 + self.mlp = dmlp_registry.get(args) + + # Calculate the number of bits needed to represent the column indices + # in the intermediate sparse matrix. + max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) + self.transpose_sort_end_bit = max( + int(np.ceil(np.log2(max_column_index))), + 1, + ) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + # + # NOTE: Our sort operation uses the same width indices as the input values. + # To avoid overflow when we have large activation matrices we cast to + # 32-bit before sorting. + _, gather_indices = ops.sort( + column_indices.int(), + self.transpose_sort_end_bit, + ) + + # There are a constant number of blocks in every row of the sparse matrix. + # A blocks offset is: + # + # row_index * blocks_per_row + column_index % blocks_per_row + # + # Once we have the block offsets ordered for transposition we can divide + # by blocks_per_row to get the transposed column indices. + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + if nnz_per_column.dim() == 0: + # This addresses an edge case when ffn_hidden_size is equal to self.blocking. 
+ nnz_per_column = nnz_per_column.unsqueeze(0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + if self.ffn_hidden_size % self.blocking != 0: + raise ValueError( + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + + f'the block size {self.blocking}. Please update your configuration.', + ) + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // self.blocking + blocks_per_row = self.ffn_hidden_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology( + padded_bins, + self.blocking, + block_rows, + blocks_per_row, + ) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=common.dtype(self.args), + device='meta', + ) + shape = ( + padded_tokens, + self.ffn_hidden_size * mpu.experts_per_rank(self.args), + ) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, + row_indices, + column_indices, + offsets, + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_experts = top_experts.int() + bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Caculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = promote_scalar(bins) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def sparse_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather( + x, + indices, + bin_ids, + bins, + padded_bins, + self.top_k, + ) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. 
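+        # padded_scatter reverses the padded_gather above: the top_k expert
+        # outputs for each token are weighted by expert_weights and combined
+        # back into one row per token, producing a [sl * bs, hs] tensor that
+        # the caller reshapes to [sl, bs, hs].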
+ x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.top_k, + ) + return x, tokens_per_expert + + # For use in the base-class parallel_forward_once. + def sparse_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + return ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + top_k, + ) + + def grouped_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + out = self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + -1, # unused + self.args.moe_top_k, + ) + return out, tokens_per_expert + + def grouped_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Perform the expert computation. + x = self.mlp(x, tokens_per_expert) + + # Un-route the data for the MoE output. 
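+        # ops.scatter is the unpadded counterpart of the padded_scatter used
+        # on the sparse path: the grouped GEMM path never rounds bins up to
+        # the block size, so the plain gather/scatter pair is sufficient.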
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + def forward_once(self, x, expert_weights, top_experts): + if self.args.mlp_impl == 'sparse': + return self.sparse_forward_once(x, expert_weights, top_experts) + else: + return self.grouped_forward_once(x, expert_weights, top_experts) + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ): + if self.args.mlp_impl == 'sparse': + return self.sparse_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + else: + return self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + + +class dMoE(moe.MoE): + + def _init_experts_mlp(self, args: Arguments): + return ParallelDroplessMLP(args) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/gelu.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..40b601d4a8eb59e80e1090f31f4172b8f7fb7549 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/gelu.py @@ -0,0 +1,43 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk +import torch +import torch.nn.functional as F + + +@torch.jit.script +def _gelu_backward_inplace(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) + return g.mul_(ff) + + +def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): + # NOTE: The two sparse matrices must have the same topology. + if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): + return stk.Matrix( + x.size(), + _gelu_backward_inplace(grad.data, x.data), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + return _gelu_backward_inplace(grad, x) + + +def gelu(x: stk.Matrix): + assert isinstance(x, stk.Matrix) + return stk.Matrix( + x.size(), + F.gelu(x.data, approximate='tanh'), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/glu.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe0c915c307e7f7cade3ea3ff679399635fcd81 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/glu.py @@ -0,0 +1,223 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk.ops +import torch + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import Arguments +from megablocks.layers.mlp import ( + SharedMLP, + SparseMLP, + create_dmoe_expert_weights, + resolve_dtensor, +) + + +class SparseGLU(SparseMLP): + + def __init__(self, args: Arguments): + super().__init__(args) + self.v1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + with torch.no_grad(): + self.v1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + + mpu.set_expert_model_parallel_attributes( + 
self.v1, + self._should_set_parallelism_attribute, + ) + + def forward(self, x, topo): + if self.args.memory_optimized_mlp: + raise NotImplementedError( + 'Memory optimized implementation not yet supported with GLU with sparse kernels.', + ) + + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Compute the GLU. + x1 = stk.ops.sdd(x, w1.t(), topo) + x2 = stk.ops.sdd(x, v1.t(), topo) + + activation_fn_out = act_fn(x1, self.args.activation_fn) + x1 = stk.ops.mul(activation_fn_out, x2) + + return stk.ops.dsd(x1, w2) + + +class MemoryOptimizedGroupedGLU(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + v1 = v1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) + + # GeLU. + activation_fn_out = activation_fn(sdd_out) * v1_out + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, v1, w2 = saved_tensors[:3] + batch_sizes = saved_tensors[3] + x = saved_tensors[4] + sdd_out, v1_out = saved_tensors[5:7] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + v1_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * v1_out + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + dv1_out = v1_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dv1. 
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + dx = ddsd_out + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) + dx += gg.backend.gmm(dv1_out, v1, batch_sizes) + return dx, dw1, dv1, dw2, None, None + + +memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply + + +class GroupedGLU(SparseGLU): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, v1, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.v1), + self.scale_grad(self.w2), + ) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Re-shape the weights for the grouped GEMMs. + ne = mpu.experts_per_rank(self.args) + w1 = w1.view(ne, -1, self.args.hidden_size) + v1 = v1.view(ne, -1, self.args.hidden_size) + w2 = w2.view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_glu( + x, + w1, + v1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) + x1 = self.args.activation_fn(x1) * x2 + return gg.ops.gmm(x1, w2, batch_sizes) + + +class SharedGLU(SharedMLP): + """GPU for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class + """ + + def __init__(self, args: Arguments): + super().__init__(args) + self.gate_proj = args.fc_cls( + args.hidden_size, + self.args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/memory_test.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/memory_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4acbd94f212ce906ace9de78dd3ffc6afa03f97e --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/memory_test.py @@ -0,0 +1,102 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import gc + +import torch +import torch.distributed as dist + +from megablocks.layers import arguments, dmoe + +_TESTS = ((8, 2048, 4096, 4096, 32, 4),) + + +def get_tensors(): + ptrs = set() + out = [] + for obj in gc.get_objects(): + if torch.is_tensor(obj): + if not obj.is_contiguous() or obj.data_ptr() in ptrs: + continue + out.append(obj) + ptrs.add(obj.data_ptr()) + return out + + +def test_memory( + group, + batch_size, + sequence_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, +): + args = arguments.Arguments( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=num_experts, + moe_top_k=top_k, + moe_expert_model_parallelism=True, + expert_parallel_group=group, + fp16=False, + bf16=True, + device=torch.cuda.current_device(), + ) + layer = dmoe.dMoE(args).cuda() + + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16).requires_grad_(True) + torch.cuda.empty_cache() + + # Run forward + backward. + # with torch.autograd.detect_anomaly(): + out, _ = layer(x) + out.mean().backward() + + # Report peak memory. 
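+    # max_memory_allocated / max_memory_reserved report CUDA high-water
+    # marks accumulated since program start (or the last reset), so they
+    # capture the combined forward + backward peak of the step above.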
+ mem = torch.cuda.max_memory_allocated() + print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) + print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) + + # Calculate weight and gradient memory usage. + weight_memory = 2 * ( + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() + ) + + def grad_numel(x): + if x.grad is not None: + return x.grad.numel() + return 0 + + grad_memory = 2 * ( + grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + ) + weight_memory += grad_memory + + print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) + print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) + + # Manually calculate GPU memory usage from the garbage + # collector. + gc.collect() + total = 0 + tensors = get_tensors() + tensors = sorted(tensors, key=lambda x: -x.numel()) + for i, t in enumerate(tensors): + total += t.numel() + print(f'{i}: {t.shape}, {t.numel() * 2}') + del tensors + + print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _TESTS: + test_memory(group, *args) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6f4d82441e3a9c185db5bfdf686d53790dde26 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mlp.py @@ -0,0 +1,574 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import stk +import stk.backend.triton_kernels +import stk.ops +import torch +from packaging import version + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, gelu, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + + +class ScaleGradient(torch.autograd.Function): + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx: Any, x: torch.Tensor, scale: float): + ctx.scale = scale + return x + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: torch.Tensor, grad: torch.Tensor): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +def resolve_dtensor(weight: torch.Tensor): + if version.parse(torch.__version__) >= version.parse('2.0.0'): + from torch.distributed._tensor import DTensor + if isinstance(weight, DTensor): + return weight.to_local() + return weight + + +def create_moe_expert_weights( + args: Arguments, + num_experts: int, + ffn_hidden_size: int, + hidden_size: int, + init_method: InitFn, +): + # Create the entire weight matrix such that the sampled weights will + # not vary between data parallelism and expert model parallelism for + # the same random seed. + master_weights = torch.empty( + num_experts, + ffn_hidden_size, + hidden_size, + device=args.device, + dtype=common.dtype(args), + ) + init_method(master_weights) + + if not args.moe_expert_model_parallelism: + return master_weights + + # Calculate the amount of sharding in each dimension. 
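+    # expert_sharding_degree splits the expert dimension across ranks and
+    # hidden_sharding_degree splits each expert's ffn_hidden_size rows when
+    # there are more ranks than experts; their product equals the expert
+    # parallel world size (see layers/mpu.py).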
+ expert_sharding_degree = mpu.expert_sharding_degree(args) + hidden_sharding_degree = mpu.hidden_sharding_degree(args) + + # Calculate the experts per rank. + # + # NOTE: We assign ranks to be expert parallel before going + # tensor parallel. + rank = mpu.get_expert_parallel_rank(args) + expert_rank = rank % expert_sharding_degree + num_experts_per_rank = num_experts // expert_sharding_degree + start_expert = expert_rank * num_experts_per_rank + end_expert = (expert_rank + 1) * num_experts_per_rank + + # Calculate the rows per rank. + row_rank = rank // expert_sharding_degree + num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree + start_row = row_rank * num_rows_per_rank + end_row = (row_rank + 1) * num_rows_per_rank + + # Slice the weight matrix to get the chunk for this rank. + with torch.no_grad(): + weights = master_weights[start_expert:end_expert, start_row:end_row] + return weights + + +class MLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) + experts_per_rank = mpu.experts_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + args.hidden_size, + mpu.features_per_rank(args), + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + mpu.features_per_rank(args), + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + mpu.set_expert_model_parallel_attributes( + self.w1, + args.moe_expert_model_parallelism, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + args.moe_expert_model_parallelism, + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. 
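+        #
+        # w1 is generated as [experts, ffn_hidden, hidden] and stored
+        # transposed as [experts, hidden, ffn_hidden] so the forward pass
+        # can use torch.bmm(x, w1) directly; w2 keeps the
+        # [experts, ffn_hidden, hidden] layout needed by the second bmm.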
+ with torch.no_grad(): + w1 = create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) + self.w1.copy_(w1.transpose(1, 2).contiguous()) + self.w2.copy_( + create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + x = torch.bmm(x, w1) + x = self.args.activation_fn(x) + return torch.bmm(x, w2) + + +def create_dmoe_expert_weights( + args: Arguments, + num_experts: int, + rows: int, + columns: int, + init_method: InitFn, +): + weights = create_moe_expert_weights( + args, + num_experts, + rows, + columns, + init_method, + ) + return weights.view([-1, columns]) + + +class MemoryOptimizedMLP(torch.autograd.Function): + """Sparse MLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, topo, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + topo_tensors = ( + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) + + # Layer 0: x @ w1.t(). + sdd_out = stk.ops.sdd(x, w1.t(), topo) + + # GeLU. + activation_fn_out = act_fn(sdd_out, activation_fn) + + # Layer 1: x @ w2. + dsd_out = stk.ops.dsd(activation_fn_out, w2) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.shape = topo.shape + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.data.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + topo_tensors = saved_tensors[2:8] + x = saved_tensors[8] + sdd_out_data = saved_tensors[9] + + # rematerialize activation function output + activation_fn = ctx.activation_fn + sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) + activation_fn_out, activation_grad_fn = act_fn( + sdd_out, + activation_fn, + return_grad_fn=True, + ) + + # Compute dw2 with recomputed activation_fn output. + dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. 
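+        # Writing the sdd result directly into activation_fn_out's storage
+        # avoids allocating another buffer of the topology's size; this is
+        # safe because activation_fn_out was only needed for the dw2
+        # computation above.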
+ dactivation_fn_out = activation_fn_out + stk.backend.triton_kernels.sdd( + ddsd_out, + w2.t(), + dactivation_fn_out.shape, + dactivation_fn_out.data, + dactivation_fn_out.offsets, + dactivation_fn_out.row_indices, + dactivation_fn_out.column_indices, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out.data) + dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) + + # Compute dw1. + dw1 = stk.ops.dsd(dsdd_out.t(), x) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + stk.backend.triton_kernels.dsd( + dsdd_out.shape, + dsdd_out.data, + dsdd_out.offsets, + dsdd_out.row_indices, + dsdd_out.column_indices, + dsdd_out.offsets_t, + dsdd_out.column_indices_t, + dsdd_out.block_offsets_t, + False, + w1, + ddsd_out, + ) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_mlp = MemoryOptimizedMLP.apply + + +class SparseMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. + with torch.no_grad(): + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism + mpu.set_expert_model_parallel_attributes( + self.w1, + self._should_set_parallelism_attribute, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + self._should_set_parallelism_attribute, + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x, topo): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + if self.args.memory_optimized_mlp: + return memory_optimized_mlp( + x, + w1, + w2, + topo, + self.args.activation_fn, + ) + + # Compute the MLP. 
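+        #
+        # sdd multiplies the dense input by w1.t() and materializes only the
+        # blocks named in `topo`; dsd multiplies that block-sparse activation
+        # by w2 back to a dense output. Because the topology covers every
+        # routed token, no tokens are dropped.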
+ x = stk.ops.sdd(x, w1.t(), topo) + activation_fn_out = act_fn(x, self.args.activation_fn) + return stk.ops.dsd(activation_fn_out, w2) + + +class MemoryOptimizedGroupedMLP(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + + # activation_fn + activation_fn_out = activation_fn(sdd_out) + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: Any, ddsd_out: torch.Tensor): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + batch_sizes = saved_tensors[2] + x = saved_tensors[3] + sdd_out = saved_tensors[4] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply + + +class GroupedMLP(SparseMLP): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) + + # Re-shape the weights for the grouped GEMMs. 
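+        # The flat [num_rows_per_rank, hidden_size] parameters are viewed as
+        # [local_experts, ffn_hidden_per_rank, hidden_size] so the grouped
+        # GEMM treats each expert as one batch entry, with batch_sizes giving
+        # the number of tokens routed to each local expert.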
+ ne = mpu.experts_per_rank(self.args) + w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) + w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_mlp( + x, + w1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x = self.args.activation_fn(x) + return gg.ops.gmm(x, w2, batch_sizes) + + +class SharedMLP(torch.nn.Module): + """MLP for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class + """ + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self.fc_kwargs: dict[str, Any] = { + 'bias': args.bias, + 'device': args.device, + } + self.fc_kwargs.update(args.fc_kwargs) + + self.up_proj = args.fc_cls( + args.hidden_size, + args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + self.act = args.activation_fn + self.down_proj = args.fc_cls( + args.shared_expert_hidden_size, + args.hidden_size, + **self.fc_kwargs, + ) + self.down_proj._is_residual = True # a flag for llm-foundry init + + def add_experts_sharedexpert( + self, + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + ) -> torch.Tensor: + # Helper function to add expert output to shared expert output + # with optional weighted sum. + if self.args.shared_expert_weighted_sum: + # enable using weighted sum for shared expert output + # wieghted by number of experts used + t_experts = self.args.moe_top_k + 1 + sh_mlp_out = shared_expert_out / t_experts + return sh_mlp_out.add( + expert_out, + alpha=(self.args.moe_top_k / t_experts), + ) + + return shared_expert_out + expert_out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/moe.py @@ -0,0 +1,475 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist + +import megablocks.ops as ops +from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +from megablocks.layers.all_to_all import all_to_all +from megablocks.layers.arguments import Arguments + +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss(args: Arguments): + if args.moe_loss_weight == 0: + return 0.0 + + # tokens_per_expert[i].shape = (num_experts) + # expert_scores[i].shape = (tokens, num_experts) + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} token_per_experts ' 
+ f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} expert_scores ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) + + tokens = expert_scores[0].shape[0] + assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = (args.moe_num_experts * args.moe_loss_weight) + scale_denominator = (args.num_layers * tokens * args.moe_top_k) + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +# NOTE: This class defines MoE expert computation, including expert model parallel +# communication. When using FSDP on top of MegaBlocks this is the module that should +# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model +# parallel all2all. +class ParallelMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super(ParallelMLP, self).__init__() + self.args = args + + # Calculate the number of experts in total and the number of experts + # owned by this rank. + # world_size = mpu.get_expert_parallel_world_size(args) + self.num_experts = args.moe_num_experts + self.top_k = self.args.moe_top_k + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + # Expert MLP. + self.mlp = mlp.MLP(args) + + self.bias: Optional[torch.Tensor] + if self.args.bias: + # Note that the output bias is not parallelized with expert + # model parallelism. + self.bias = torch.nn.Parameter( + torch.empty( + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + # Select the forward function for the operating mode. 
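+        # forward_once handles the single-device case; parallel_forward_once
+        # adds the all-to-all permutations required when experts are sharded
+        # across the expert parallel group.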
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) + + def expert_capacity(self, tokens: int) -> int: + world_size = mpu.get_expert_parallel_world_size(self.args) + tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) + return int(self.args.moe_capacity_factor * tokens_per_expert) + + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): + """Calculate the load balancing loss contribution.""" + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == self.num_experts + assert len(tokens_per_expert.size()) == 1 + num_experts, = tokens_per_expert.size() + assert num_experts == self.num_experts + scale = self.num_experts / (tokens * self.top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # + # TODO(tgale): Is it worth doing this conversion to 32-bit + # prior? Could we place the `torch.max` operation to return + # 32-bit expert indices? + top_expert = top_expert.int() + output = ops.sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + # + # TODO(tgale): Does the sorted data produce a more favorable + # data distribution for histogram? Or is the op parallelism + # worth more? + tokens_per_expert = ops.histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None + bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + + return indices, bin_ids, bins, tokens_per_expert + + def permute_and_compute( + self, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, + ): + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output + + # Perform the expert computation. Note that we don't + # use biases for these linear operations. + x = self.mlp(x) + + # Un-route the data for the MoE output. + return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. 
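+        #
+        # expert_capacity = capacity_factor * top_k * tokens * world_size
+        # / num_experts. With illustrative numbers: 4096 tokens, top_k=2,
+        # 8 experts, world_size=1 and capacity_factor=1 give a capacity of
+        # 1024 tokens per expert; a factor of 0 yields capacity 0, which
+        # triggers the fallback to the observed maximum below.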
+ sl, bs, _ = x.size() + expert_capacity = self.expert_capacity(sl * bs) + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = self.permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + self.top_k, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignement. After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + # + # Compute the mapping of local tokens to experts. + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so every device gets the counts. + repeated_tokens_per_expert = ops.repeat( + tokens_per_expert, + (mpu.hidden_sharding_degree(self.args),), + ) + + # Pass token count information to the device on which the + # target expert resides. + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.args.expert_parallel_group, + async_op=True, + ) + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + # + # This view updates the shape of the tensor from [sl, bs, hs] to + # [sl * bs, hs] prior to the permutation. + x = x.view(-1, x.shape[-1]) + output = ops.gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + experts_per_rank = mpu.experts_per_rank(self.args) + + # Reshape to [world_size, num_experts_per_rank]. + world_size = mpu.get_expert_parallel_world_size(self.args) + repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) + + # TODO(tgale): It might be faster to do this on the GPU and + # then communicate the results back to the host. + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. 
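+            # dist.all_to_all_single expects plain Python lists for the
+            # split sizes, and tokens_received (the total across all ranks)
+            # is used below to construct expert indices for every received
+            # token.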
+ send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so devices that share experts + # get all of the tokens assigned to them. + # + # TODO(tgale): Fuse this into the prior, local permutation. + x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, + recv_counts, + send_counts, + self.args.expert_parallel_group, + async_op=True, + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + replicate_bins = ops.inclusive_cumsum( + parallel_tokens_per_expert.flatten(), + 0, + ) + replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * mpu.hidden_sharding_degree(self.args), + dtype=torch.int32, + device=indices.device, + ), + mpu.experts_per_rank(self.args), + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, + tokens_received, + ).flatten() + + # TODO(tgale): The sort_end_bit here can be reduced. + parallel_bin_ids, parallel_indices = ops.sort( + parallel_top_expert, + self.sort_end_bit, + ) + + # Calculate the bins boundaries from the token counts. + parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, + dtype=torch.int, + ) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + tokens, _ = x.size() + expert_capacity = self.expert_capacity(tokens) + if expert_capacity == 0: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + if self.args.mlp_impl == 'grouped': + # GroupedMLP requires counts on CPU. We can use the tensor already + # moved to CPU for the prior all_to_all, which avoids an extra + # device synchronization. + parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( + dim=0, + dtype=torch.int, + ) + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all( + parallel_x, + send_counts, + recv_counts, + self.args.expert_parallel_group, + ) + + # Reduce along the hidden sharding to get the final outputs. + # + # TODO(tgale): Fuse this into the following local permutation. 
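+        # When hidden_sharding_degree > 1, each token was replicated to the
+        # ranks that hold different slices of the same experts, so their
+        # partial outputs are summed here before the final un-permutation.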
+ shape = ( + mpu.hidden_sharding_degree(self.args), + -1, + self.args.hidden_size, + ) + x = ops.sum(x.view(shape), dim=0) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + return x, tokens_per_expert.flatten() + + def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + in_shape = x.size() + + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) + if self.training and self.args.moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, scores)) + x = x.view(in_shape) + if self.bias is not None: + if self.args.return_bias: + return x, self.bias + return x + self.bias + return x + + +class MoE(torch.nn.Module): + + def __init__(self, args: Arguments): + super(MoE, self).__init__() + + # Token router. + self.router = router.LearnedRouter(args) + + # Expert computation helper. + self.experts = self._init_experts_mlp(args) + + self.shared_expert = None + if args.shared_expert: + # SharedExpert computation helper. + self.shared_expert = sharedexpert_registry.get(args) + + def _init_experts_mlp(self, args: Arguments): + return ParallelMLP(args) + + def forward(self, x: torch.Tensor): + # NOTE: If we're going to cast the activations to lower precision + # do it before we permute the tokens to save bandwidth. + x = common.cast_if_autocast_enabled(x) + + # Compute the expert scores and assignments. + scores, expert_weights, top_experts = self.router(x) + + # Compute the experts. + out = self.experts(x, scores, expert_weights, top_experts) + if self.shared_expert is not None: + shared_expert_out = self.shared_expert(x) + out = self.shared_expert.add_experts_sharedexpert( + shared_expert_out, + out, + ) + return out diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mpu.py new file mode 100644 index 0000000000000000000000000000000000000000..b23213902f4567d7fdb0158cbcf5406c2b2aa601 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mpu.py @@ -0,0 +1,93 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.distributed as dist + +from megablocks.layers.arguments import Arguments + + +class MoeParam(torch.Tensor): + + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool + + +def is_moe_param(tensor: torch.Tensor) -> bool: + return hasattr(tensor, 'expert_model_parallel') + + +def get_expert_parallel_world_size(args: Arguments) -> int: + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) + + +def get_expert_parallel_rank(args: Arguments) -> int: + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) + + +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, 'expert_model_parallel') + setattr(tensor, 'expert_model_parallel', is_parallel) + + +def param_is_expert_model_parallel(param: MoeParam) -> bool: + return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) + + +def copy_expert_model_parallel_attributes( + destination_tensor: torch.Tensor, + source_tensor: torch.Tensor, +): + if hasattr(source_tensor, 'expert_model_parallel'): + setattr( + destination_tensor, + 'expert_model_parallel', + getattr(source_tensor, 
'expert_model_parallel'), + ) + + +def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + for i in range(world_size): + dist.barrier(group) + if i == rank: + print(f'rank = {rank}', *x) + + +# Helpers for expert/tensor sharding. +def expert_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = min(world_size, args.moe_num_experts) + + if (args.moe_num_experts % esd) != 0: + raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) + return esd + + +def hidden_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = expert_sharding_degree(args) + hsd = world_size // esd + + if (args.ffn_hidden_size % hsd) != 0: + raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", + ) + return hsd + + +def experts_per_rank(args: Arguments) -> int: + return args.moe_num_experts // expert_sharding_degree(args) + + +def features_per_rank(args: Arguments) -> int: + return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/router.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9dcd9e433322482a80d5b95afccee5c12368f8 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/router.py @@ -0,0 +1,114 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch + +from megablocks.layers import common +from megablocks.layers.arguments import Arguments + +_ROUTER_LOGITS = [] + + +def _save_router_logits(logits: torch.Tensor, args: Arguments): + if args.moe_zloss_weight == 0: + return + global _ROUTER_LOGITS + _ROUTER_LOGITS.append(logits) + + +def clear_router_zloss(): + global _ROUTER_LOGITS + _ROUTER_LOGITS.clear() + + +def batched_router_zloss(args: Arguments): + global _ROUTER_LOGITS + + if args.moe_zloss_weight == 0: + import warnings + warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') + return 0 + + logits_per_router = _ROUTER_LOGITS + + if args.moe_zloss_in_fp32: + logits_per_router = [logits.float() for logits in logits_per_router] + + unscaled_zloss_per_router = torch.stack([ + torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router + ]) + + return args.moe_zloss_weight * unscaled_zloss_per_router + + +# NOTE: To enable end-to-end benchmarking without convergence we +# support a flag to force the router to assign tokens uniformly +# across the experts. We do this with a custom autograd operation +# so that PyTorch still executes the full set of router operation. +class _UniformExpertAssignment(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, num_experts: int): + out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) + out = torch.remainder(out, num_experts) + return out.view(x.shape) + + +_uniform_expert_assignment = _UniformExpertAssignment.apply + + +class LearnedRouter(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + + # Learned router parameters. + # + # NOTE: This weight matrix is not parallelized with expert model + # parallelism. 
Each device needs the entire router weight matrix + # so that it can route its batch of data correctly. + self.layer = torch.nn.Linear( + args.hidden_size, + args.moe_num_experts, + bias=False, + dtype=common.dtype(args), + device=args.device, + ) + args.init_method(self.layer.weight) + + def jitter(self, x: torch.Tensor): + low: float = 1.0 - self.args.moe_jitter_eps + high: float = 1.0 + self.args.moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return low + noise * (high - low) + + def _top_k(self, scores: torch.Tensor): + if self.args.moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, self.args.moe_top_k, dim=-1) + + def forward(self, x: torch.Tensor): + if self.training and self.args.moe_jitter_eps is not None: + x = x * self.jitter(x) + + logits = self.layer(x.view(-1, x.shape[-1])) + _save_router_logits(logits, self.args) + scores = logits.softmax(dim=-1) + expert_weights, expert_indices = self._top_k(scores) + if self.args.moe_normalize_expert_weights: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + + expert_indices = ( + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) if self.args.uniform_expert_assignment else expert_indices + ) + return scores, expert_weights, expert_indices diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..0f62db39a2dda5642f6baa9d9b949f7c31cf6d35 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -0,0 +1,30 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +_REGISTRY = { + 'mlp': mlp.SharedMLP, + 'glu': glu.SharedGLU, +} + + +def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: + """Returns an SharedMLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + SharedMLP instance. + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated SharedMLP constructed using the input args. 
+ """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + return _REGISTRY[args.mlp_type](args) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from megablocks.ops.binned_gather import binned_gather +from megablocks.ops.binned_scatter import binned_scatter +from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum +from megablocks.ops.gather import gather +from megablocks.ops.histogram import histogram +from megablocks.ops.padded_gather import padded_gather +from megablocks.ops.padded_scatter import padded_scatter +from megablocks.ops.repeat import repeat +from megablocks.ops.replicate import replicate +from megablocks.ops.round_up import round_up +from megablocks.ops.scatter import scatter +from megablocks.ops.sort import sort +from megablocks.ops.sum import sum +from megablocks.ops.topology import topology + +__all__ = [ + 'binned_gather', + 'binned_scatter', + 'exclusive_cumsum', + 'inclusive_cumsum', + 'gather', + 'histogram', + 'padded_gather', + 'padded_scatter', + 'repeat', + 'replicate', + 'round_up', + 'scatter', + 'sort', + 'sum', + 'topology', +] diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..47b95301dbb35b0c3154d1ff2878d20792ce5cb2 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -0,0 +1,60 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + +from megablocks import benchmark_util +from megablocks.layers.all_to_all import all_to_all + +_ALL_TO_ALL_BENCHMARK = ( + (8, 1024), + (16, 1024), + (32, 1024), + (64, 1024), + (128, 1024), + (256, 1024), + (512, 1024), + (1024, 1024), + (2 * 1024, 1024), + (4 * 1024, 1024), + (8 * 1024, 1024), + (16 * 1024, 1024), + (32 * 1024, 1024), + (64 * 1024, 1024), + (128 * 1024, 1024), + (256 * 1024, 1024), + (512 * 1024, 1024), + (1024 * 1024, 1024), +) + + +def benchmark_all_to_all(group, sl, hs): + world_size = dist.get_world_size(group) + assert (sl % world_size) == 0 + send_recv_sizes = [sl // world_size] * world_size + + x = torch.randn((sl, hs)).cuda().half() + + details = { + 'world_size': world_size, + 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. 
+ } + + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) + + time, std = benchmark_util.benchmark_function(benchmark) + + if dist.get_rank(group) == 0: + benchmark_util.log_benchmark('All-To-All', details, time, std) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _ALL_TO_ALL_BENCHMARK: + benchmark_all_to_all(group, *args) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..89cce1b627137d7f987baca62fd6d82c5c04659a --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_gather.py @@ -0,0 +1,37 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_gather kernel. +class BinnedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): + ctx.save_for_backward(indices, bins) + ctx.top_k = top_k + return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + indices, bins = ctx.saved_tensors + out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) + return out, None, None, None, None + + +binned_gather = BinnedGatherOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ce0d6f5a890a58b89e6a501216a9822341323f --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_scatter.py @@ -0,0 +1,59 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_scatter kernel. +class BinnedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + assert len(x.size()) == 3 + ctx.bin_size = x.size(1) + ctx.top_k = top_k + + # TODO(tgale): Don't save 'x' for backwards if we don't need to + # calculate the gradient w.r.t. 'weights'. 
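+ # NOTE: 'x' is saved here so the backward pass can compute the expert-weight
+ # gradient via kernels.binned_scatter_wgrad when ctx.needs_input_grad[2] is set.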
+ ctx.save_for_backward(x, indices, weights, bins) + return kernels.binned_scatter(x, indices, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + x, indices, weights, bins = ctx.saved_tensors + out = kernels.binned_gather( + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[2]: + wgrad = kernels.binned_scatter_wgrad( + x, + grad, + indices, + bins, + ctx.top_k, + ) + return out, None, wgrad, None, None + + +binned_scatter = BinnedScatterOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/cumsum.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7c5137c900d66a959a76fa681436362a4df906 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/cumsum.py @@ -0,0 +1,52 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + # import megablocks_ops as ops # type: ignore + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrappers for cumsum kernels. +# NOTE: Does not support gradients. +class ExclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int): + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.exclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.exclusive_cumsum(x, dim, out) + return out + + +exclusive_cumsum = ExclusiveCumsumOp.apply + + +class InclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.inclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.inclusive_cumsum(x, dim, out) + return out + + +inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/gather.py new file mode 100644 index 0000000000000000000000000000000000000000..41b09a1233e8a996ff1062579cd0810d095ad1e6 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/gather.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for gather kernel. 
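+# gather permutes rows of 'x' (tokens, hidden_size) into expert-contiguous
+# order using the routing metadata produced by ops.sort / ops.histogram:
+# 'indices' and 'bin_ids' are (tokens * top_k,) and 'bins' is the inclusive
+# cumsum of tokens per expert. The backward pass is the matching scatter.
+#
+# Illustrative usage (assuming routing metadata built as in the benchmarks):
+#   permuted = gather(x, indices, bin_ids, bins, top_k)  # (tokens * top_k, hidden_size)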
+class GatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins) + ctx.top_k = top_k + return kernels.gather(x, indices, bin_ids, None, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins = ctx.saved_tensors + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) + return out, None, None, None, None, None + + +gather = GatherOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram.py new file mode 100644 index 0000000000000000000000000000000000000000..77e06a29163aa335dc61e8c50932d5f83322ae1d --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram.py @@ -0,0 +1,27 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for histogram kernel. +# NOTE: Does not support gradients. +class HistogramOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, max_val: float): + return ops.histogram(x, max_val) + + +histogram = HistogramOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..9de8e65281c829ba63b5c60853dfa5bb0a333988 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -0,0 +1,78 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_HISTOGRAM_TESTS = ( + (16384, torch.int32, 2), + (16384, torch.int32, 4), + (16384, torch.int32, 8), + (16384, torch.int32, 16), + (16384, torch.int32, 32), + (16384, torch.int32, 64), + (16384, torch.int32, 128), + (16384, torch.int32, 256), +) + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. 
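+ # The warm-up call absorbs one-time setup costs (e.g. CUDA context
+ # initialization and kernel caching) so they do not skew the timed iterations.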
+ fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class HistogramBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testHistogram(self, n, dtype, max_val): + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testTorchHistogram(self, n, dtype, max_val): + x = torch.randint(0, 128, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..bfa7b7c2f697f9454c5381e31ee78040be5b229e --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -0,0 +1,403 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import stk +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + + +# Calling tensor.t() calls tensor.transpose(0, 1) which calls +# torch.as_strided(...). Circumvent this chain to avoid an overhead +# this adds. +def transpose_view(x): + return torch.as_strided( + x, + (x.shape[1], x.shape[0]), + (x.stride()[1], x.stride()[0]), + ) + + +_MATMUL_TESTS = ( + (64 * 1024, 512, 2048, 64), + (32 * 1024, 768, 3072, 64), + (8 * 1024, 1024, 4096, 64), + (4 * 2048, 4096, 4 * 4096, 4), +) + + +def log_benchmark(name, arguments, time, std, flops): + benchmark_util.log_benchmark(name, arguments, time, std) + print('flops = {:.2f}B'.format(flops / 1e9)) + print('throughput = {:.2f}T'.format(flops / 1e9 / time)) + print('=' * 60) + + +class MatmulBenchmark(parameterized.TestCase): + + def build_sparse_matrix(self, x, padded_bins, fhs, ne): + blocking = 128 + padded_tokens, _ = x.size() + assert padded_tokens % blocking == 0 + assert fhs % blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // blocking + blocks_per_row = fhs // blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. 
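+ # ops.topology expands the padded per-expert token counts into the
+ # block-column indices of the block-sparse topology matrix.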
+ column_indices = ops.topology( + padded_bins, + blocking, + block_rows, + blocks_per_row, + ) + data = torch.empty( + column_indices.numel(), + blocking, + blocking, + dtype=torch.float16, + device=x.device, + ) + shape = (padded_tokens, fhs * ne) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix(shape, data, row_indices, column_indices, offsets) + + def build_input_matrix(self, sl, hs, ne): + x = torch.randn((sl, hs)).cuda().half() + + # Assign tokens to experts uniformly. + top_expert = torch.arange(0, sl).cuda().int() % ne + + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) + return out, padded_bins + + def build_weight_matrix(self, ne, hs, fhs): + return torch.randn((hs, ne * fhs)).cuda().half() + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(x, w, topo) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd::SDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(topo, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradX::DSD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + topo = topo.t() + + def benchmark(): + return stk.ops.dsd(topo, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DSD::NN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): 
+ x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(out, w, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::SDD::NT', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + x = x.t() + + def benchmark(): + return stk.ops.dsd(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + + w = w.transpose(1, 2).contiguous() + w = w.transpose(1, 2) + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd:DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + w = w.transpose(1, 2).contiguous() + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradX:DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + out = out.transpose(1, 2) + + def benchmark(): + return torch.bmm(out, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradW:DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) 
+ + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + w = torch.transpose(w, 1, 2) + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + x = torch.transpose(x, 1, 2) + + def benchmark(): + return torch.bmm(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..f272a7768dc6468e81bd2cd25294ca16a6826c08 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_gather.py @@ -0,0 +1,55 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_gather kernel. +class PaddedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins, padded_bins) + ctx.top_k = top_k + return kernels.padded_gather( + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins, padded_bins = ctx.saved_tensors + out = kernels.padded_scatter( + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) + return out, None, None, None, None, None + + +padded_gather = PaddedGatherOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff81dd9456a04258346266e9e85871bda56c65b --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter.py @@ -0,0 +1,98 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_scatter kernel. 
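+# padded_scatter is the inverse of padded_gather: it copies rows from the
+# padded, expert-contiguous layout back to token order, optionally scaling
+# each copy by its expert weight and summing the top_k copies per token.
+# The backward pass re-runs padded_gather for the data gradient and, when
+# needed, padded_scatter_wgrad for the expert-weight gradient.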
+class PaddedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward( + indices, + bin_ids, + weights, + bins, + padded_bins, + *maybe_x, + ) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.padded_gather( + grad, + indices, + bin_ids, + weights, + bins, + padded_bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.padded_scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + padded_bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None, None + + +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): + return PaddedScatterOp.apply( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..81dde4e4b4c0f550c6027390e8de285fa4d842f7 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -0,0 +1,66 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PADDED_SCATTER_BENCHMARK = ( + # dMoE-Medium, 8-way EMP. + (1024 * 16, 1024, 8, 4), + # dMoE-Medium, post-all-to-all. + (1024 * 16 * 4, 1024, 8, 1), +) + + +class PaddedScatterTest(parameterized.TestCase): + + @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) + def testPaddedScatter(self, sl, hs, ne, top_k): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + # Sample weights for the scatter reduce. + weights = torch.rand((sl * top_k,)).cuda().half() + + # Gather the data to prepare for backwards. 
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + def benchmark(): + return ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + time, std = benchmark_util.benchmark_function(benchmark) + benchmark_util.log_benchmark( + 'Padded Scatter', + { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + 'top_k': top_k, + }, + time, + std, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..837f07e2858d5c65c14e3f262617e7d2985ce901 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -0,0 +1,149 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PERMUTE_TESTS = ( + (16384, 768, 2), + (16384, 768, 4), + (16384, 768, 8), + (16384, 768, 16), + (16384, 768, 32), + (16384, 768, 64), + (16384, 768, 128), + (16384 * 8, 768, 2), + (16384 * 8, 768, 4), + (16384 * 8, 768, 8), + (16384 * 8, 768, 16), + (16384 * 8, 768, 32), + (16384 * 8, 768, 64), + (16384 * 8, 768, 128), +) + + +class PermuteBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedGather(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.binned_gather(x, indices, bins, ec) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedScatter(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.binned_gather(x, indices, bins, ec) + + def benchmark(): + return ops.binned_scatter(x, indices, bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedGather(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. 
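+ # Then build the routing metadata used by padded_gather: sort the
+ # assignments, histogram tokens per expert, round up to the 128-row
+ # blocking, and take inclusive cumsums for the (padded) bin boundaries.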
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedScatter(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testCopy(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + # ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + y = x.clone() + + def benchmark(): + return y.copy_(x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/repeat.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/repeat.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def repeat(x: torch.Tensor, tiling: torch.Size): + if all((t == 1 for t in tiling)): + return x + return x.repeat(*tiling) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..e570e22093339d1b71dc830edf656897c67f40e7 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/replicate.py @@ -0,0 +1,36 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. 
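+# NOTE: the import below targets the bundled megablocks._ops module; the
+# error message still refers to the legacy 'megablocks_ops' extension name.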
+try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for replicate kernel. +class ReplicateOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): + ctx.save_for_backward(bins) + out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) + ops.replicate_forward(x, bins, out) + return out + + @staticmethod + def backward(ctx: Any, grad: torch.Tensor): + bins, = ctx.saved_tensors + out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) + ops.replicate_backward(grad, bins, out) + return out, None, None + + +replicate = ReplicateOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/round_up.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/round_up.py new file mode 100644 index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/round_up.py @@ -0,0 +1,14 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def round_up(x: torch.Tensor, value: int): + assert isinstance(value, int) + assert x.dtype == torch.int32 + + # TODO(tgale): If this becomes and issue + # do this in a custom kernel. We only expect + # to use this on arrays of less than 1k elements. + return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..a5aaafc402a651a47595c25dbf71e382903e5022 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/scatter.py @@ -0,0 +1,72 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for scatter kernel. 
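+# scatter un-permutes expert-ordered rows back to token order, scaling each
+# row by its expert weight and reducing the top_k copies per token. It is
+# implemented via the padded kernels with 'bins' doubling as 'padded_bins'.
+#
+# Illustrative usage (names as used in the MoE layers above):
+#   out = scatter(expert_out, indices, bin_ids, expert_weights, bins, top_k)
+#   # out: (tokens, hidden_size)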
+class ScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins = saved_tensors[:4] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.gather( + grad, + indices, + bin_ids, + weights, + bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None + + +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +) -> Optional[torch.Tensor]: + return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort.py new file mode 100644 index 0000000000000000000000000000000000000000..a64a4553323d8de71a1d5f1821bac50249c68197 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional, Tuple + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + +_BITS_FOR_DTYPE = { + torch.int16: 16, + torch.int32: 32, + torch.int64: 64, +} + + +# Autograd wrapper for sort kernel. +# NOTE: Does not support gradients. 
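+# sort performs a radix sort keyed on the lowest 'end_bit' bits and returns
+# (sorted values, permutation indices). When 'end_bit' is omitted it falls
+# back to the full bit width of the input dtype.
+#
+# Illustrative example with 2-bit keys:
+#   top_expert = tensor([2, 0, 1, 0], dtype=torch.int32)
+#   bin_ids, indices = sort(top_expert, 2)   # -> [0, 0, 1, 2], [1, 3, 2, 0]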
+class SortOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + if end_bit is None: + end_bit = _BITS_FOR_DTYPE[x.dtype] + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + ops.sort(x, end_bit, x_out, iota_out) + return (x_out, iota_out) + + +sort = SortOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -0,0 +1,85 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_SORT_TESTS = ( + (16384, torch.int32, None), + (16384, torch.int32, 2), + (16384, torch.int32, 128), +) + +_BASELINE_SORT_TESTS = ((16384,),) + + +def numpy_dtype(dtype): + types = { + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + } + return types[dtype] + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. + fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class SortBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_SORT_TESTS) + def testSort(self, n, dtype, max_val): + if max_val is None: + max_val = np.iinfo(numpy_dtype(dtype)).max + end_bit = int(np.ceil(np.log2(max_val))) + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_BASELINE_SORT_TESTS) + def testTorchSort(self, n): + x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) + arguments = { + 'n': n, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sum.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sum.py new file mode 100644 index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sum.py @@ -0,0 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import torch + + +def sum(x: torch.Tensor, dim: int = 0): + if x.shape[dim] == 1: + return x.squeeze(dim=dim) + return x.sum(dim=dim) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/topology.py new file mode 100644 index 
0000000000000000000000000000000000000000..b4a693a82a36c052e51afb4a02077e7849a4abf1 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/topology.py @@ -0,0 +1,45 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for topology kernel. +# NOTE: Does not support gradients. +class TopologyOp(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, + ): + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + ops.indices( + padded_bins, + block_size, + output_block_rows, + output_block_columns, + out, + ) + return out + + +topology = TopologyOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d8a2495c4baab89e5f90d84fa92f406a29ea10 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/__init__.py @@ -0,0 +1,195 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from ._ops import ops + +from megablocks.layers.arguments import Arguments +from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE +from megablocks.layers.glu import SparseGLU +from megablocks.layers.mlp import MLP, SparseMLP +from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +# This section contains the direct kernel exports (not inlcuded in the original code) +def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute exclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.exclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute inclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.inclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: + """ + Compute histogram of input tensor values. + + Args: + x: Input tensor + num_bins: Number of histogram bins + + Returns: + Histogram tensor with counts for each bin + """ + return ops.histogram(x, num_bins) + + +def indices( + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, +) -> torch.Tensor: + """ + Construct indices from padded bins for sparse operations. 
+ + Args: + padded_bins: Tensor containing bin boundaries + block_size: Size of each block + output_block_rows: Number of rows in output blocks + output_block_columns: Number of columns in output blocks + + Returns: + Tensor containing constructed indices + """ + return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) + + +def replicate_forward( + x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Forward pass of replicate operation - replicate values according to bin sizes. + + Args: + x: Input tensor with values to replicate + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_forward(x, bins, out) + + +def replicate_backward( + grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Backward pass of replicate operation - reduce gradients back to bins. + + Args: + grad: Gradient tensor to reduce + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_backward(grad, bins, out) + + +def sort( + x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor +) -> torch.Tensor: + """ + Radix sort with index tracking. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + x_out: Output tensor for sorted values + iota_out: Output tensor for sorted indices + + Returns: + The sorted values tensor + """ + return ops.sort(x, end_bit, x_out, iota_out) + + +# Convenience functions for common use cases +def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: + """ + Compute cumulative sum with automatic output allocation. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum (default: last dimension) + exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum + + Returns: + New tensor containing the cumulative sum + """ + out = torch.empty_like(x) + if exclusive: + return exclusive_cumsum(x, dim, out) + else: + return inclusive_cumsum(x, dim, out) + + +def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sort tensor and return both sorted values and indices. 
+ + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + + Returns: + Tuple of (sorted_values, sorted_indices) + """ + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + sort(x, end_bit, x_out, iota_out) + return x_out, iota_out + + +# Export public API +__all__ = [ + # Direct kernel exports + "exclusive_cumsum", + "inclusive_cumsum", + "histogram", + "indices", + "replicate_forward", + "replicate_backward", + "sort", + "cumsum", + "argsort", + # Original exports + "Arguments", + "ParallelDroplessMLP", + "dMoE", + "SparseGLU", + "MLP", + "SparseMLP", + "MoE", + "ParallelMLP", + "get_load_balancing_loss", +] diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..6ff924df77bd5ce379e5f717d59e29f407da7962 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:99859c18a9a4e7ec11e7fa805e2225644f8e5f51e2546b4525ddf8e939f48874 +size 11832392 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd42432c910375fb1713b3bf29145e48b556498 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _megablocks_359242d +ops = torch.ops._megablocks_359242d + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_megablocks_359242d::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_version.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_version.py @@ -0,0 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +"""The MegaBlocks Version.""" + +__version__ = '0.11.0.dev0' diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/backend/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/backend/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/backend/kernels.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/backend/kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..b584ceede926ca30abef2dec581cb3ff329e8e16 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/backend/kernels.py @@ -0,0 +1,543 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') + + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') + + +def assert_equal(a, b): + if a != b: + raise ValueError(f'Expected 
dimensions to be equal but got {a} and {b}.',) + + +# a: (tokens, hidden_size), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy( + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Our index into array 'a'. + index_a = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'b'. + index_b = offset_in_bin + if bin_idx > 0: + index_b += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: Because of the padding, the output size is dynamic. + # We load the final padded bin bound to get the output rows. + output_rows = padded_bins[-1].cpu().item() + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def gather(x, indices, bin_ids, weights, bins, top_k): + # Validate the input shapes. 
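+    #
+    # Shapes (as in the layout comments above): x is (tokens, hidden_size);
+    # indices and bin_ids are (tokens * top_k,); bins is (num_experts,).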
+ assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: There is no padding so the output rows equals the + # input rows multiplied by top_k. + output_rows = x.shape[0] * top_k + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + out, + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + + +def scatter(x, indices, bin_ids, weights, bins, top_k): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +# x: (tokens, top_k, hidden_size), real +# grad: (tokens, hidden_size), real. +# wgrad: (tokens, top_k), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy_wgrad( + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Our index into 'tokens * top_k'. + index_out = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'x'. + index_x = offset_in_bin + if bin_idx > 0: + index_x += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. 
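+    #
+    # wgrad receives one scalar per (token, top_k) entry, grad is indexed per token
+    # (hence the division by TOP_K), and x is indexed by its padded, expert-major row.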
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + _padded_copy_wgrad[(indices.shape[0],)]( + x, + grad, + out, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + TOP_K=top_k, + ) + return out + + +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy( + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_b = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_a = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + # + # NOTE: We need to zero the output in both directions. 
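+    # In the A->B direction the capacity slots past num_tokens are never written,
+    # and in the B->A direction tokens dropped by the capacity limit are never
+    # written, which is why the callers allocate the destination with torch.zeros.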
+ iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + num_experts = bins.shape[0] + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + + _binned_copy[(num_experts, expert_capacity)]( + x, + out, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def binned_scatter(x, indices, weights, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( + out, + x, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=hidden_size, + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy_wgrad( + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_x = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_out = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. 
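+    #
+    # wgrad stores one scalar per routed (token, top_k) entry, grad is indexed per
+    # token, and x by its (expert, capacity-slot) row as produced by binned_gather.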
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def binned_scatter_wgrad(x, grad, indices, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) + _binned_copy_wgrad[(num_experts, expert_capacity)]( + x, + grad, + out, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS=hidden_size, + TOP_K=top_k, + ) + return out diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/bak.__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/bak.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5217959caf74527e3bf7f80db6f93be21c016963 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/bak.__init__.py @@ -0,0 +1,23 @@ +from megablocks_moe.megablocks import ( + MoE, + dMoE, + get_load_balancing_loss, + ParallelMLP, + ParallelDroplessMLP, + SparseMLP, + MLP, + SparseGLU, + Arguments, +) + +__all__ = [ + "MoE", + "dMoE", + "get_load_balancing_loss", + "ParallelMLP", + "ParallelDroplessMLP", + "SparseMLP", + "MLP", + "SparseGLU", + "Arguments", +] diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/benchmark_util.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/benchmark_util.py new file mode 100644 index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/benchmark_util.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + + +def log_benchmark(name, arguments, time, std): + print('=' * 60) + print(f'{name} Benchmark') + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) + print('=' * 60) + + +def benchmark_function(fn, iterations=100, warmup=10): + # Warmup iterations. 
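+    #
+    # The first calls typically pay for CUDA context setup, kernel autotuning and
+    # allocator growth, so they are excluded from the timed loop below.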
+ for _ in range(warmup): + fn() + + times = [] + for i in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + fn() + end.record() + + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + return np.mean(times), np.std(times) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6d3f977f360fd0ad5800c3b5da9ce57be794b9b8 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm_util.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import warnings + +_grouped_gemm_is_available: bool = False +try: + import grouped_gemm + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn('Grouped GEMM not available.') + + +def grouped_gemm_is_available(): + return _grouped_gemm_is_available + + +def assert_grouped_gemm_is_available(): + msg = ( + 'Grouped GEMM not available. Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg + + +backend = grouped_gemm.backend if grouped_gemm_is_available() else None +ops = grouped_gemm.ops if grouped_gemm_is_available() else None diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..849b023b1a0f765fb1e26addf4670dc6db785a52 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# from megablocks.layers.dmoe import dMoE +from megablocks.layers.moe import MoE + +__all__ = [ + 'MoE', + # 'dMoE', +] diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/activation_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..a31770ba179fed06abf2da10102ccaeed1d3ee4e --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/activation_fn.py @@ -0,0 +1,33 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Union + +import torch +from stk import Matrix + + +def act_fn( + x: Matrix, + function: Callable, + return_grad_fn: bool = False, + **kwargs, +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) + with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): + if return_grad_fn: + x.data.requires_grad = True + out = function(x.data, **kwargs) + y = Matrix( + x.size(), + out, + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + if return_grad_fn: + return y, out.backward + return y diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/all_to_all.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/all_to_all.py @@ -0,0 +1,54 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + + +class 
AllToAllOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) + + ctx.input_shape = x.shape + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + ctx.group = group + handle = dist.all_to_all_single( + out, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + return out, handle + + @staticmethod + def backward(ctx, grad, _): + if ctx.needs_input_grad[0]: + out = torch.empty( + ctx.input_shape, + device=grad.device, + dtype=grad.dtype, + ) + dist.all_to_all_single( + out, + grad, + output_split_sizes=ctx.input_split_sizes, + input_split_sizes=ctx.output_split_sizes, + group=ctx.group, + ) + return out, None, None, None, None + return None, None, None, None, None + + +def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): + return AllToAllOp.apply( + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..3962c771c90012535aab058443f89c541a2e9236 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/arguments.py @@ -0,0 +1,100 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from functools import partial +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +import megablocks.grouped_gemm_util as grouped_gemm + +# Type annotation for in-place Tensor initialization function. +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] + +_ALLOWED_BITWIDTHS = (-1, 4, 8) + +DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') + + +@dataclasses.dataclass +class Arguments: + # Model arguments. + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + num_layers: int = 1 + bias: bool = True + return_bias: bool = True + activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN + + # MoE arguments. + moe_num_experts: int = 1 + moe_top_k: int = 1 + moe_capacity_factor: int = 1 + moe_normalize_expert_weights: Optional[Union[int, float]] = None + moe_loss_weight: float = 0.1 + moe_jitter_eps: Optional[float] = None + moe_lbl_in_fp32: bool = False + + # Parallelism arguments. + moe_expert_model_parallelism: bool = False + expert_parallel_group: Optional[dist.ProcessGroup] = None + pipeline_model_parallel_size: int = 1 + num_layers_per_virtual_pipeline_stage: Optional[int] = None + + # Compute arguments. + memory_optimized_mlp: bool = False + mlp_type: str = 'mlp' + mlp_impl: str = 'sparse' + + # Initialization arguments. + fp16: bool = True + bf16: bool = False + device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) + init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) + output_layer_init_method: InitFn = init_method + + # Benchmarking arguments. 
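+    # (When set, token-to-expert routing is replaced by a fixed uniform assignment,
+    # which gives a perfectly balanced load for benchmarking purposes.)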
+ uniform_expert_assignment: bool = False + + # shared expert arguments + shared_expert: bool = False # enable using shared expert + fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) + fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers + remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored + shared_expert_hidden_size: Optional[ + int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size + shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) + + # Router Z-loss arguments + moe_zloss_weight: float = 0 # 1e-3 is a reasonable value + moe_zloss_in_fp32: bool = False + + def __post_init__(self): + # Sparse MLP is not supported with triton >=3.2.0 + # TODO: Remove this once sparse is supported with triton >=3.2.0 + if self.__getattribute__('mlp_impl') == 'sparse': + try: + import triton + if triton.__version__ >= '3.2.0': + raise ValueError( + 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.', + ) + except ImportError: + raise ImportError('Triton is required for sparse MLP implementation') + + if self.__getattribute__('mlp_impl') == 'grouped': + grouped_gemm.assert_grouped_gemm_is_available() + + if self.shared_expert_hidden_size is None: + self.shared_expert_hidden_size = self.ffn_hidden_size + + +def from_megatron(megatron_args: Any): + args = Arguments() + for field in dataclasses.fields(args): + if hasattr(megatron_args, field.name): + setattr(args, field.name, getattr(megatron_args, field.name)) + return args diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/common.py new file mode 100644 index 0000000000000000000000000000000000000000..ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/common.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from megablocks.layers.arguments import Arguments + + +def dtype(args: Arguments): + if args.fp16: + return torch.float16 + elif args.bf16: + return torch.bfloat16 + return None + + +def cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == 'cuda': + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == 'cpu': + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d765bd04387a29bddd2789fae04821635b555e82 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -0,0 +1,42 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +MlpType = Union[mlp.SparseMLP, glu.SparseGLU] + +_REGISTRY = { + 'mlp': { + 'grouped': mlp.GroupedMLP, + 'sparse': mlp.SparseMLP, + }, + 'glu': { + 'grouped': glu.GroupedGLU, + 'sparse': glu.SparseGLU, + }, +} + + +def get(args: Arguments) -> MlpType: + 
"""Returns an MLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + MLP instance. This only contains MLPs for use in dMoEs + (ie. only for the dropless versions of MoEs). + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated MLP constructed using the input args. + """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + if args.mlp_impl not in _REGISTRY[args.mlp_type]: + raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) + + return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..205727ff4d63f9e8dc9648acaac99a97f3394d6f --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmoe.py @@ -0,0 +1,327 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import stk.ops +import torch +from stk import Matrix + +import megablocks.ops as ops +# from megablocks.ops import ops +from megablocks.layers import common, dmlp_registry, moe, mpu +from megablocks.layers.arguments import Arguments + + +def promote_scalar(x): + return x.view(1) if not len(x.size()) else x + + +class ParallelDroplessMLP(moe.ParallelMLP): + + def __init__(self, args: Arguments): + super(ParallelDroplessMLP, self).__init__(args) + self.hidden_size = args.hidden_size + self.ffn_hidden_size = mpu.features_per_rank(args) + self.blocking = 128 + self.mlp = dmlp_registry.get(args) + + # Calculate the number of bits needed to represent the column indices + # in the intermediate sparse matrix. + max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) + self.transpose_sort_end_bit = max( + int(np.ceil(np.log2(max_column_index))), + 1, + ) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + # + # NOTE: Our sort operation uses the same width indices as the input values. + # To avoid overflow when we have large activation matrices we cast to + # 32-bit before sorting. + _, gather_indices = ops.sort( + column_indices.int(), + self.transpose_sort_end_bit, + ) + + # There are a constant number of blocks in every row of the sparse matrix. + # A blocks offset is: + # + # row_index * blocks_per_row + column_index % blocks_per_row + # + # Once we have the block offsets ordered for transposition we can divide + # by blocks_per_row to get the transposed column indices. + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + if nnz_per_column.dim() == 0: + # This addresses an edge case when ffn_hidden_size is equal to self.blocking. 
+ nnz_per_column = nnz_per_column.unsqueeze(0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + if self.ffn_hidden_size % self.blocking != 0: + raise ValueError( + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + + f'the block size {self.blocking}. Please update your configuration.', + ) + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // self.blocking + blocks_per_row = self.ffn_hidden_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology( + padded_bins, + self.blocking, + block_rows, + blocks_per_row, + ) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=common.dtype(self.args), + device='meta', + ) + shape = ( + padded_tokens, + self.ffn_hidden_size * mpu.experts_per_rank(self.args), + ) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, + row_indices, + column_indices, + offsets, + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_experts = top_experts.int() + bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Caculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = promote_scalar(bins) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def sparse_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather( + x, + indices, + bin_ids, + bins, + padded_bins, + self.top_k, + ) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. 
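+        # (padded_scatter also applies expert_weights and, for top_k > 1, sums over
+        # the copies of each token, undoing the padded_gather above.)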
+ x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.top_k, + ) + return x, tokens_per_expert + + # For use in the base-class parallel_forward_once. + def sparse_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + return ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + top_k, + ) + + def grouped_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + out = self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + -1, # unused + self.args.moe_top_k, + ) + return out, tokens_per_expert + + def grouped_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Perform the expert computation. + x = self.mlp(x, tokens_per_expert) + + # Un-route the data for the MoE output. 
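+        # (ops.scatter is the unpadded counterpart of padded_scatter: the grouped
+        # path introduces no block padding, so bins serve as the padded bins too.)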
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + def forward_once(self, x, expert_weights, top_experts): + if self.args.mlp_impl == 'sparse': + return self.sparse_forward_once(x, expert_weights, top_experts) + else: + return self.grouped_forward_once(x, expert_weights, top_experts) + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ): + if self.args.mlp_impl == 'sparse': + return self.sparse_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + else: + return self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + + +class dMoE(moe.MoE): + + def _init_experts_mlp(self, args: Arguments): + return ParallelDroplessMLP(args) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/gelu.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..40b601d4a8eb59e80e1090f31f4172b8f7fb7549 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/gelu.py @@ -0,0 +1,43 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk +import torch +import torch.nn.functional as F + + +@torch.jit.script +def _gelu_backward_inplace(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) + return g.mul_(ff) + + +def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): + # NOTE: The two sparse matrices must have the same topology. + if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): + return stk.Matrix( + x.size(), + _gelu_backward_inplace(grad.data, x.data), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + return _gelu_backward_inplace(grad, x) + + +def gelu(x: stk.Matrix): + assert isinstance(x, stk.Matrix) + return stk.Matrix( + x.size(), + F.gelu(x.data, approximate='tanh'), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/glu.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe0c915c307e7f7cade3ea3ff679399635fcd81 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/glu.py @@ -0,0 +1,223 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk.ops +import torch + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import Arguments +from megablocks.layers.mlp import ( + SharedMLP, + SparseMLP, + create_dmoe_expert_weights, + resolve_dtensor, +) + + +class SparseGLU(SparseMLP): + + def __init__(self, args: Arguments): + super().__init__(args) + self.v1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + with torch.no_grad(): + self.v1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + + mpu.set_expert_model_parallel_attributes( + 
self.v1, + self._should_set_parallelism_attribute, + ) + + def forward(self, x, topo): + if self.args.memory_optimized_mlp: + raise NotImplementedError( + 'Memory optimized implementation not yet supported with GLU with sparse kernels.', + ) + + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Compute the GLU. + x1 = stk.ops.sdd(x, w1.t(), topo) + x2 = stk.ops.sdd(x, v1.t(), topo) + + activation_fn_out = act_fn(x1, self.args.activation_fn) + x1 = stk.ops.mul(activation_fn_out, x2) + + return stk.ops.dsd(x1, w2) + + +class MemoryOptimizedGroupedGLU(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + v1 = v1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) + + # GeLU. + activation_fn_out = activation_fn(sdd_out) * v1_out + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, v1, w2 = saved_tensors[:3] + batch_sizes = saved_tensors[3] + x = saved_tensors[4] + sdd_out, v1_out = saved_tensors[5:7] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + v1_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * v1_out + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + dv1_out = v1_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dv1. 
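+        # (the same contraction as dw1, but driven by the gradient flowing through
+        # the non-activated v1 branch of the GLU.)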
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + dx = ddsd_out + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) + dx += gg.backend.gmm(dv1_out, v1, batch_sizes) + return dx, dw1, dv1, dw2, None, None + + +memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply + + +class GroupedGLU(SparseGLU): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, v1, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.v1), + self.scale_grad(self.w2), + ) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Re-shape the weights for the grouped GEMMs. + ne = mpu.experts_per_rank(self.args) + w1 = w1.view(ne, -1, self.args.hidden_size) + v1 = v1.view(ne, -1, self.args.hidden_size) + w2 = w2.view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_glu( + x, + w1, + v1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) + x1 = self.args.activation_fn(x1) * x2 + return gg.ops.gmm(x1, w2, batch_sizes) + + +class SharedGLU(SharedMLP): + """GPU for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class + """ + + def __init__(self, args: Arguments): + super().__init__(args) + self.gate_proj = args.fc_cls( + args.hidden_size, + self.args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/memory_test.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/memory_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4acbd94f212ce906ace9de78dd3ffc6afa03f97e --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/memory_test.py @@ -0,0 +1,102 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import gc + +import torch +import torch.distributed as dist + +from megablocks.layers import arguments, dmoe + +_TESTS = ((8, 2048, 4096, 4096, 32, 4),) + + +def get_tensors(): + ptrs = set() + out = [] + for obj in gc.get_objects(): + if torch.is_tensor(obj): + if not obj.is_contiguous() or obj.data_ptr() in ptrs: + continue + out.append(obj) + ptrs.add(obj.data_ptr()) + return out + + +def test_memory( + group, + batch_size, + sequence_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, +): + args = arguments.Arguments( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=num_experts, + moe_top_k=top_k, + moe_expert_model_parallelism=True, + expert_parallel_group=group, + fp16=False, + bf16=True, + device=torch.cuda.current_device(), + ) + layer = dmoe.dMoE(args).cuda() + + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16).requires_grad_(True) + torch.cuda.empty_cache() + + # Run forward + backward. + # with torch.autograd.detect_anomaly(): + out, _ = layer(x) + out.mean().backward() + + # Report peak memory. 
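+    #
+    # NOTE: max_memory_allocated() reports the peak since the last call to
+    # torch.cuda.reset_peak_memory_stats() (here, since the start of the process).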
+ mem = torch.cuda.max_memory_allocated() + print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) + print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) + + # Calculate weight and gradient memory usage. + weight_memory = 2 * ( + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() + ) + + def grad_numel(x): + if x.grad is not None: + return x.grad.numel() + return 0 + + grad_memory = 2 * ( + grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + ) + weight_memory += grad_memory + + print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) + print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) + + # Manually calculate GPU memory usage from the garbage + # collector. + gc.collect() + total = 0 + tensors = get_tensors() + tensors = sorted(tensors, key=lambda x: -x.numel()) + for i, t in enumerate(tensors): + total += t.numel() + print(f'{i}: {t.shape}, {t.numel() * 2}') + del tensors + + print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _TESTS: + test_memory(group, *args) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6f4d82441e3a9c185db5bfdf686d53790dde26 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mlp.py @@ -0,0 +1,574 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import stk +import stk.backend.triton_kernels +import stk.ops +import torch +from packaging import version + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, gelu, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + + +class ScaleGradient(torch.autograd.Function): + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx: Any, x: torch.Tensor, scale: float): + ctx.scale = scale + return x + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: torch.Tensor, grad: torch.Tensor): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +def resolve_dtensor(weight: torch.Tensor): + if version.parse(torch.__version__) >= version.parse('2.0.0'): + from torch.distributed._tensor import DTensor + if isinstance(weight, DTensor): + return weight.to_local() + return weight + + +def create_moe_expert_weights( + args: Arguments, + num_experts: int, + ffn_hidden_size: int, + hidden_size: int, + init_method: InitFn, +): + # Create the entire weight matrix such that the sampled weights will + # not vary between data parallelism and expert model parallelism for + # the same random seed. + master_weights = torch.empty( + num_experts, + ffn_hidden_size, + hidden_size, + device=args.device, + dtype=common.dtype(args), + ) + init_method(master_weights) + + if not args.moe_expert_model_parallelism: + return master_weights + + # Calculate the amount of sharding in each dimension. 
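+    #
+    # rank % expert_sharding_degree selects which experts this rank owns, and
+    # rank // expert_sharding_degree selects which slice of each expert's
+    # ffn_hidden_size rows it owns.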
+ expert_sharding_degree = mpu.expert_sharding_degree(args) + hidden_sharding_degree = mpu.hidden_sharding_degree(args) + + # Calculate the experts per rank. + # + # NOTE: We assign ranks to be expert parallel before going + # tensor parallel. + rank = mpu.get_expert_parallel_rank(args) + expert_rank = rank % expert_sharding_degree + num_experts_per_rank = num_experts // expert_sharding_degree + start_expert = expert_rank * num_experts_per_rank + end_expert = (expert_rank + 1) * num_experts_per_rank + + # Calculate the rows per rank. + row_rank = rank // expert_sharding_degree + num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree + start_row = row_rank * num_rows_per_rank + end_row = (row_rank + 1) * num_rows_per_rank + + # Slice the weight matrix to get the chunk for this rank. + with torch.no_grad(): + weights = master_weights[start_expert:end_expert, start_row:end_row] + return weights + + +class MLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) + experts_per_rank = mpu.experts_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + args.hidden_size, + mpu.features_per_rank(args), + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + mpu.features_per_rank(args), + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + mpu.set_expert_model_parallel_attributes( + self.w1, + args.moe_expert_model_parallelism, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + args.moe_expert_model_parallelism, + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. 
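+        #
+        # w1 is stored transposed, as (experts, hidden_size, ffn features per rank),
+        # so the forward pass can call torch.bmm(x, w1) without an extra transpose.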
+ with torch.no_grad(): + w1 = create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) + self.w1.copy_(w1.transpose(1, 2).contiguous()) + self.w2.copy_( + create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + x = torch.bmm(x, w1) + x = self.args.activation_fn(x) + return torch.bmm(x, w2) + + +def create_dmoe_expert_weights( + args: Arguments, + num_experts: int, + rows: int, + columns: int, + init_method: InitFn, +): + weights = create_moe_expert_weights( + args, + num_experts, + rows, + columns, + init_method, + ) + return weights.view([-1, columns]) + + +class MemoryOptimizedMLP(torch.autograd.Function): + """Sparse MLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, topo, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + topo_tensors = ( + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) + + # Layer 0: x @ w1.t(). + sdd_out = stk.ops.sdd(x, w1.t(), topo) + + # GeLU. + activation_fn_out = act_fn(sdd_out, activation_fn) + + # Layer 1: x @ w2. + dsd_out = stk.ops.dsd(activation_fn_out, w2) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.shape = topo.shape + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.data.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + topo_tensors = saved_tensors[2:8] + x = saved_tensors[8] + sdd_out_data = saved_tensors[9] + + # rematerialize activation function output + activation_fn = ctx.activation_fn + sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) + activation_fn_out, activation_grad_fn = act_fn( + sdd_out, + activation_fn, + return_grad_fn=True, + ) + + # Compute dw2 with recomputed activation_fn output. + dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. 
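+        # The sdd call below writes ddsd_out @ w2.t(), with the same block-sparse
+        # topology, directly into activation_fn_out's storage instead of allocating
+        # a new matrix.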
+ dactivation_fn_out = activation_fn_out + stk.backend.triton_kernels.sdd( + ddsd_out, + w2.t(), + dactivation_fn_out.shape, + dactivation_fn_out.data, + dactivation_fn_out.offsets, + dactivation_fn_out.row_indices, + dactivation_fn_out.column_indices, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out.data) + dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) + + # Compute dw1. + dw1 = stk.ops.dsd(dsdd_out.t(), x) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + stk.backend.triton_kernels.dsd( + dsdd_out.shape, + dsdd_out.data, + dsdd_out.offsets, + dsdd_out.row_indices, + dsdd_out.column_indices, + dsdd_out.offsets_t, + dsdd_out.column_indices_t, + dsdd_out.block_offsets_t, + False, + w1, + ddsd_out, + ) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_mlp = MemoryOptimizedMLP.apply + + +class SparseMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. + with torch.no_grad(): + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism + mpu.set_expert_model_parallel_attributes( + self.w1, + self._should_set_parallelism_attribute, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + self._should_set_parallelism_attribute, + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x, topo): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + if self.args.memory_optimized_mlp: + return memory_optimized_mlp( + x, + w1, + w2, + topo, + self.args.activation_fn, + ) + + # Compute the MLP. 
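+        #
+        # sdd: dense x times dense w1.t() with the block-sparse output pattern given
+        # by `topo`; dsd: block-sparse activations times dense w2 back to dense.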
+ x = stk.ops.sdd(x, w1.t(), topo) + activation_fn_out = act_fn(x, self.args.activation_fn) + return stk.ops.dsd(activation_fn_out, w2) + + +class MemoryOptimizedGroupedMLP(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + + # activation_fn + activation_fn_out = activation_fn(sdd_out) + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: Any, ddsd_out: torch.Tensor): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + batch_sizes = saved_tensors[2] + x = saved_tensors[3] + sdd_out = saved_tensors[4] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply + + +class GroupedMLP(SparseMLP): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) + + # Re-shape the weights for the grouped GEMMs. 
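+ # w1/w2 are stored flat as [local_experts * features_per_rank, hidden]; the
+ # grouped GEMM expects one slab per expert. Illustrative sizes: 8 local experts
+ # with 4096 features each and hidden = 1024 become a [8, 4096, 1024] view.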
+ ne = mpu.experts_per_rank(self.args) + w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) + w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_mlp( + x, + w1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x = self.args.activation_fn(x) + return gg.ops.gmm(x, w2, batch_sizes) + + +class SharedMLP(torch.nn.Module): + """MLP for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class + """ + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self.fc_kwargs: dict[str, Any] = { + 'bias': args.bias, + 'device': args.device, + } + self.fc_kwargs.update(args.fc_kwargs) + + self.up_proj = args.fc_cls( + args.hidden_size, + args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + self.act = args.activation_fn + self.down_proj = args.fc_cls( + args.shared_expert_hidden_size, + args.hidden_size, + **self.fc_kwargs, + ) + self.down_proj._is_residual = True # a flag for llm-foundry init + + def add_experts_sharedexpert( + self, + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + ) -> torch.Tensor: + # Helper function to add expert output to shared expert output + # with optional weighted sum. + if self.args.shared_expert_weighted_sum: + # enable using weighted sum for shared expert output + # wieghted by number of experts used + t_experts = self.args.moe_top_k + 1 + sh_mlp_out = shared_expert_out / t_experts + return sh_mlp_out.add( + expert_out, + alpha=(self.args.moe_top_k / t_experts), + ) + + return shared_expert_out + expert_out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/moe.py @@ -0,0 +1,475 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist + +import megablocks.ops as ops +from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +from megablocks.layers.all_to_all import all_to_all +from megablocks.layers.arguments import Arguments + +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss(args: Arguments): + if args.moe_loss_weight == 0: + return 0.0 + + # tokens_per_expert[i].shape = (num_experts) + # expert_scores[i].shape = (tokens, num_experts) + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} token_per_experts ' 
+ f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} expert_scores ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) + + tokens = expert_scores[0].shape[0] + assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = (args.moe_num_experts * args.moe_loss_weight) + scale_denominator = (args.num_layers * tokens * args.moe_top_k) + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +# NOTE: This class defines MoE expert computation, including expert model parallel +# communication. When using FSDP on top of MegaBlocks this is the module that should +# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model +# parallel all2all. +class ParallelMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super(ParallelMLP, self).__init__() + self.args = args + + # Calculate the number of experts in total and the number of experts + # owned by this rank. + # world_size = mpu.get_expert_parallel_world_size(args) + self.num_experts = args.moe_num_experts + self.top_k = self.args.moe_top_k + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + # Expert MLP. + self.mlp = mlp.MLP(args) + + self.bias: Optional[torch.Tensor] + if self.args.bias: + # Note that the output bias is not parallelized with expert + # model parallelism. + self.bias = torch.nn.Parameter( + torch.empty( + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + # Select the forward function for the operating mode. 
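+ # forward_once computes experts on purely local tokens; parallel_forward_once
+ # additionally exchanges tokens across ranks with all_to_all when expert model
+ # parallelism is enabled.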
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) + + def expert_capacity(self, tokens: int) -> int: + world_size = mpu.get_expert_parallel_world_size(self.args) + tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) + return int(self.args.moe_capacity_factor * tokens_per_expert) + + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): + """Calculate the load balancing loss contribution.""" + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == self.num_experts + assert len(tokens_per_expert.size()) == 1 + num_experts, = tokens_per_expert.size() + assert num_experts == self.num_experts + scale = self.num_experts / (tokens * self.top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # + # TODO(tgale): Is it worth doing this conversion to 32-bit + # prior? Could we place the `torch.max` operation to return + # 32-bit expert indices? + top_expert = top_expert.int() + output = ops.sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + # + # TODO(tgale): Does the sorted data produce a more favorable + # data distribution for histogram? Or is the op parallelism + # worth more? + tokens_per_expert = ops.histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None + bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + + return indices, bin_ids, bins, tokens_per_expert + + def permute_and_compute( + self, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, + ): + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output + + # Perform the expert computation. Note that we don't + # use biases for these linear operations. + x = self.mlp(x) + + # Un-route the data for the MoE output. + return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. 
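+ # Worked example with illustrative numbers: sl * bs = 4096 tokens, top_k = 2,
+ # 8 experts, world_size = 1, capacity_factor = 1.0 gives
+ # expert_capacity = int(1.0 * 2 * 4096 * 1 / 8) = 1024 slots per expert.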
+ sl, bs, _ = x.size() + expert_capacity = self.expert_capacity(sl * bs) + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = self.permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + self.top_k, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignement. After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + # + # Compute the mapping of local tokens to experts. + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so every device gets the counts. + repeated_tokens_per_expert = ops.repeat( + tokens_per_expert, + (mpu.hidden_sharding_degree(self.args),), + ) + + # Pass token count information to the device on which the + # target expert resides. + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.args.expert_parallel_group, + async_op=True, + ) + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + # + # This view updates the shape of the tensor from [sl, bs, hs] to + # [sl * bs, hs] prior to the permutation. + x = x.view(-1, x.shape[-1]) + output = ops.gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + experts_per_rank = mpu.experts_per_rank(self.args) + + # Reshape to [world_size, num_experts_per_rank]. + world_size = mpu.get_expert_parallel_world_size(self.args) + repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) + + # TODO(tgale): It might be faster to do this on the GPU and + # then communicate the results back to the host. + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. 
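+ # send_counts[i] is the number of local tokens destined for rank i and
+ # recv_counts[i] the number this rank will receive from rank i, both obtained
+ # by summing the per-expert token counts over each rank's experts.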
+ send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so devices that share experts + # get all of the tokens assigned to them. + # + # TODO(tgale): Fuse this into the prior, local permutation. + x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, + recv_counts, + send_counts, + self.args.expert_parallel_group, + async_op=True, + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + replicate_bins = ops.inclusive_cumsum( + parallel_tokens_per_expert.flatten(), + 0, + ) + replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * mpu.hidden_sharding_degree(self.args), + dtype=torch.int32, + device=indices.device, + ), + mpu.experts_per_rank(self.args), + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, + tokens_received, + ).flatten() + + # TODO(tgale): The sort_end_bit here can be reduced. + parallel_bin_ids, parallel_indices = ops.sort( + parallel_top_expert, + self.sort_end_bit, + ) + + # Calculate the bins boundaries from the token counts. + parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, + dtype=torch.int, + ) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + tokens, _ = x.size() + expert_capacity = self.expert_capacity(tokens) + if expert_capacity == 0: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + if self.args.mlp_impl == 'grouped': + # GroupedMLP requires counts on CPU. We can use the tensor already + # moved to CPU for the prior all_to_all, which avoids an extra + # device synchronization. + parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( + dim=0, + dtype=torch.int, + ) + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all( + parallel_x, + send_counts, + recv_counts, + self.args.expert_parallel_group, + ) + + # Reduce along the hidden sharding to get the final outputs. + # + # TODO(tgale): Fuse this into the following local permutation. 
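+ # Each hidden-sharding replica computed a partial sum from its slice of the
+ # FFN features; view the result as [hidden_sharding_degree, tokens, hidden]
+ # and reduce over the first dimension.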
+ shape = ( + mpu.hidden_sharding_degree(self.args), + -1, + self.args.hidden_size, + ) + x = ops.sum(x.view(shape), dim=0) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + return x, tokens_per_expert.flatten() + + def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + in_shape = x.size() + + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) + if self.training and self.args.moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, scores)) + x = x.view(in_shape) + if self.bias is not None: + if self.args.return_bias: + return x, self.bias + return x + self.bias + return x + + +class MoE(torch.nn.Module): + + def __init__(self, args: Arguments): + super(MoE, self).__init__() + + # Token router. + self.router = router.LearnedRouter(args) + + # Expert computation helper. + self.experts = self._init_experts_mlp(args) + + self.shared_expert = None + if args.shared_expert: + # SharedExpert computation helper. + self.shared_expert = sharedexpert_registry.get(args) + + def _init_experts_mlp(self, args: Arguments): + return ParallelMLP(args) + + def forward(self, x: torch.Tensor): + # NOTE: If we're going to cast the activations to lower precision + # do it before we permute the tokens to save bandwidth. + x = common.cast_if_autocast_enabled(x) + + # Compute the expert scores and assignments. + scores, expert_weights, top_experts = self.router(x) + + # Compute the experts. + out = self.experts(x, scores, expert_weights, top_experts) + if self.shared_expert is not None: + shared_expert_out = self.shared_expert(x) + out = self.shared_expert.add_experts_sharedexpert( + shared_expert_out, + out, + ) + return out diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mpu.py new file mode 100644 index 0000000000000000000000000000000000000000..b23213902f4567d7fdb0158cbcf5406c2b2aa601 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mpu.py @@ -0,0 +1,93 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.distributed as dist + +from megablocks.layers.arguments import Arguments + + +class MoeParam(torch.Tensor): + + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool + + +def is_moe_param(tensor: torch.Tensor) -> bool: + return hasattr(tensor, 'expert_model_parallel') + + +def get_expert_parallel_world_size(args: Arguments) -> int: + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) + + +def get_expert_parallel_rank(args: Arguments) -> int: + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) + + +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, 'expert_model_parallel') + setattr(tensor, 'expert_model_parallel', is_parallel) + + +def param_is_expert_model_parallel(param: MoeParam) -> bool: + return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) + + +def copy_expert_model_parallel_attributes( + destination_tensor: torch.Tensor, + source_tensor: torch.Tensor, +): + if hasattr(source_tensor, 'expert_model_parallel'): + setattr( + destination_tensor, + 'expert_model_parallel', + getattr(source_tensor, 
'expert_model_parallel'), + ) + + +def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + for i in range(world_size): + dist.barrier(group) + if i == rank: + print(f'rank = {rank}', *x) + + +# Helpers for expert/tensor sharding. +def expert_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = min(world_size, args.moe_num_experts) + + if (args.moe_num_experts % esd) != 0: + raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) + return esd + + +def hidden_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = expert_sharding_degree(args) + hsd = world_size // esd + + if (args.ffn_hidden_size % hsd) != 0: + raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", + ) + return hsd + + +def experts_per_rank(args: Arguments) -> int: + return args.moe_num_experts // expert_sharding_degree(args) + + +def features_per_rank(args: Arguments) -> int: + return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/router.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9dcd9e433322482a80d5b95afccee5c12368f8 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/router.py @@ -0,0 +1,114 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch + +from megablocks.layers import common +from megablocks.layers.arguments import Arguments + +_ROUTER_LOGITS = [] + + +def _save_router_logits(logits: torch.Tensor, args: Arguments): + if args.moe_zloss_weight == 0: + return + global _ROUTER_LOGITS + _ROUTER_LOGITS.append(logits) + + +def clear_router_zloss(): + global _ROUTER_LOGITS + _ROUTER_LOGITS.clear() + + +def batched_router_zloss(args: Arguments): + global _ROUTER_LOGITS + + if args.moe_zloss_weight == 0: + import warnings + warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') + return 0 + + logits_per_router = _ROUTER_LOGITS + + if args.moe_zloss_in_fp32: + logits_per_router = [logits.float() for logits in logits_per_router] + + unscaled_zloss_per_router = torch.stack([ + torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router + ]) + + return args.moe_zloss_weight * unscaled_zloss_per_router + + +# NOTE: To enable end-to-end benchmarking without convergence we +# support a flag to force the router to assign tokens uniformly +# across the experts. We do this with a custom autograd operation +# so that PyTorch still executes the full set of router operation. +class _UniformExpertAssignment(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, num_experts: int): + out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) + out = torch.remainder(out, num_experts) + return out.view(x.shape) + + +_uniform_expert_assignment = _UniformExpertAssignment.apply + + +class LearnedRouter(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + + # Learned router parameters. + # + # NOTE: This weight matrix is not parallelized with expert model + # parallelism. 
Each device needs the entire router weight matrix + # so that it can route its batch of data correctly. + self.layer = torch.nn.Linear( + args.hidden_size, + args.moe_num_experts, + bias=False, + dtype=common.dtype(args), + device=args.device, + ) + args.init_method(self.layer.weight) + + def jitter(self, x: torch.Tensor): + low: float = 1.0 - self.args.moe_jitter_eps + high: float = 1.0 + self.args.moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return low + noise * (high - low) + + def _top_k(self, scores: torch.Tensor): + if self.args.moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, self.args.moe_top_k, dim=-1) + + def forward(self, x: torch.Tensor): + if self.training and self.args.moe_jitter_eps is not None: + x = x * self.jitter(x) + + logits = self.layer(x.view(-1, x.shape[-1])) + _save_router_logits(logits, self.args) + scores = logits.softmax(dim=-1) + expert_weights, expert_indices = self._top_k(scores) + if self.args.moe_normalize_expert_weights: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + + expert_indices = ( + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) if self.args.uniform_expert_assignment else expert_indices + ) + return scores, expert_weights, expert_indices diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..0f62db39a2dda5642f6baa9d9b949f7c31cf6d35 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -0,0 +1,30 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +_REGISTRY = { + 'mlp': mlp.SharedMLP, + 'glu': glu.SharedGLU, +} + + +def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: + """Returns an SharedMLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + SharedMLP instance. + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated SharedMLP constructed using the input args. 
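+ 
+ Example (illustrative): args.mlp_type == 'glu' resolves to glu.SharedGLU(args),
+ while 'mlp' resolves to mlp.SharedMLP(args).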
+ """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + return _REGISTRY[args.mlp_type](args) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from megablocks.ops.binned_gather import binned_gather +from megablocks.ops.binned_scatter import binned_scatter +from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum +from megablocks.ops.gather import gather +from megablocks.ops.histogram import histogram +from megablocks.ops.padded_gather import padded_gather +from megablocks.ops.padded_scatter import padded_scatter +from megablocks.ops.repeat import repeat +from megablocks.ops.replicate import replicate +from megablocks.ops.round_up import round_up +from megablocks.ops.scatter import scatter +from megablocks.ops.sort import sort +from megablocks.ops.sum import sum +from megablocks.ops.topology import topology + +__all__ = [ + 'binned_gather', + 'binned_scatter', + 'exclusive_cumsum', + 'inclusive_cumsum', + 'gather', + 'histogram', + 'padded_gather', + 'padded_scatter', + 'repeat', + 'replicate', + 'round_up', + 'scatter', + 'sort', + 'sum', + 'topology', +] diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..47b95301dbb35b0c3154d1ff2878d20792ce5cb2 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -0,0 +1,60 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + +from megablocks import benchmark_util +from megablocks.layers.all_to_all import all_to_all + +_ALL_TO_ALL_BENCHMARK = ( + (8, 1024), + (16, 1024), + (32, 1024), + (64, 1024), + (128, 1024), + (256, 1024), + (512, 1024), + (1024, 1024), + (2 * 1024, 1024), + (4 * 1024, 1024), + (8 * 1024, 1024), + (16 * 1024, 1024), + (32 * 1024, 1024), + (64 * 1024, 1024), + (128 * 1024, 1024), + (256 * 1024, 1024), + (512 * 1024, 1024), + (1024 * 1024, 1024), +) + + +def benchmark_all_to_all(group, sl, hs): + world_size = dist.get_world_size(group) + assert (sl % world_size) == 0 + send_recv_sizes = [sl // world_size] * world_size + + x = torch.randn((sl, hs)).cuda().half() + + details = { + 'world_size': world_size, + 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. 
+ } + + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) + + time, std = benchmark_util.benchmark_function(benchmark) + + if dist.get_rank(group) == 0: + benchmark_util.log_benchmark('All-To-All', details, time, std) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _ALL_TO_ALL_BENCHMARK: + benchmark_all_to_all(group, *args) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..89cce1b627137d7f987baca62fd6d82c5c04659a --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_gather.py @@ -0,0 +1,37 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_gather kernel. +class BinnedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): + ctx.save_for_backward(indices, bins) + ctx.top_k = top_k + return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + indices, bins = ctx.saved_tensors + out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) + return out, None, None, None, None + + +binned_gather = BinnedGatherOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ce0d6f5a890a58b89e6a501216a9822341323f --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_scatter.py @@ -0,0 +1,59 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_scatter kernel. +class BinnedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + assert len(x.size()) == 3 + ctx.bin_size = x.size(1) + ctx.top_k = top_k + + # TODO(tgale): Don't save 'x' for backwards if we don't need to + # calculate the gradient w.r.t. 'weights'. 
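+ # ('x' is only consumed by binned_scatter_wgrad in backward, guarded by the
+ # ctx.needs_input_grad[2] check there.)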
+ ctx.save_for_backward(x, indices, weights, bins) + return kernels.binned_scatter(x, indices, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + x, indices, weights, bins = ctx.saved_tensors + out = kernels.binned_gather( + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[2]: + wgrad = kernels.binned_scatter_wgrad( + x, + grad, + indices, + bins, + ctx.top_k, + ) + return out, None, wgrad, None, None + + +binned_scatter = BinnedScatterOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/cumsum.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7c5137c900d66a959a76fa681436362a4df906 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/cumsum.py @@ -0,0 +1,52 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + # import megablocks_ops as ops # type: ignore + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrappers for cumsum kernels. +# NOTE: Does not support gradients. +class ExclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int): + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.exclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.exclusive_cumsum(x, dim, out) + return out + + +exclusive_cumsum = ExclusiveCumsumOp.apply + + +class InclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.inclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.inclusive_cumsum(x, dim, out) + return out + + +inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/gather.py new file mode 100644 index 0000000000000000000000000000000000000000..41b09a1233e8a996ff1062579cd0810d095ad1e6 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/gather.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for gather kernel. 
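+# The forward permutes tokens into expert order with kernels.gather; the
+# backward undoes that permutation with kernels.scatter using the saved
+# indices, bin_ids and bins.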
+class GatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins) + ctx.top_k = top_k + return kernels.gather(x, indices, bin_ids, None, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins = ctx.saved_tensors + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) + return out, None, None, None, None, None + + +gather = GatherOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram.py new file mode 100644 index 0000000000000000000000000000000000000000..77e06a29163aa335dc61e8c50932d5f83322ae1d --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram.py @@ -0,0 +1,27 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for histogram kernel. +# NOTE: Does not support gradients. +class HistogramOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, max_val: float): + return ops.histogram(x, max_val) + + +histogram = HistogramOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..9de8e65281c829ba63b5c60853dfa5bb0a333988 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -0,0 +1,78 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_HISTOGRAM_TESTS = ( + (16384, torch.int32, 2), + (16384, torch.int32, 4), + (16384, torch.int32, 8), + (16384, torch.int32, 16), + (16384, torch.int32, 32), + (16384, torch.int32, 64), + (16384, torch.int32, 128), + (16384, torch.int32, 256), +) + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. 
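+ # Timing uses CUDA events around each call and synchronizes before reading
+ # elapsed_time, so the reported milliseconds reflect completed GPU work.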
+ fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class HistogramBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testHistogram(self, n, dtype, max_val): + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testTorchHistogram(self, n, dtype, max_val): + x = torch.randint(0, 128, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..bfa7b7c2f697f9454c5381e31ee78040be5b229e --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -0,0 +1,403 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import stk +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + + +# Calling tensor.t() calls tensor.transpose(0, 1) which calls +# torch.as_strided(...). Circumvent this chain to avoid an overhead +# this adds. +def transpose_view(x): + return torch.as_strided( + x, + (x.shape[1], x.shape[0]), + (x.stride()[1], x.stride()[0]), + ) + + +_MATMUL_TESTS = ( + (64 * 1024, 512, 2048, 64), + (32 * 1024, 768, 3072, 64), + (8 * 1024, 1024, 4096, 64), + (4 * 2048, 4096, 4 * 4096, 4), +) + + +def log_benchmark(name, arguments, time, std, flops): + benchmark_util.log_benchmark(name, arguments, time, std) + print('flops = {:.2f}B'.format(flops / 1e9)) + print('throughput = {:.2f}T'.format(flops / 1e9 / time)) + print('=' * 60) + + +class MatmulBenchmark(parameterized.TestCase): + + def build_sparse_matrix(self, x, padded_bins, fhs, ne): + blocking = 128 + padded_tokens, _ = x.size() + assert padded_tokens % blocking == 0 + assert fhs % blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // blocking + blocks_per_row = fhs // blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. 
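+ # Roughly: ops.topology expands the padded per-expert bin boundaries into the
+ # column index of every nonzero block; with blocking = 128, a bin of 512
+ # padded tokens contributes 4 block-rows of nonzero blocks (illustrative).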
+ column_indices = ops.topology( + padded_bins, + blocking, + block_rows, + blocks_per_row, + ) + data = torch.empty( + column_indices.numel(), + blocking, + blocking, + dtype=torch.float16, + device=x.device, + ) + shape = (padded_tokens, fhs * ne) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix(shape, data, row_indices, column_indices, offsets) + + def build_input_matrix(self, sl, hs, ne): + x = torch.randn((sl, hs)).cuda().half() + + # Assign tokens to experts uniformly. + top_expert = torch.arange(0, sl).cuda().int() % ne + + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) + return out, padded_bins + + def build_weight_matrix(self, ne, hs, fhs): + return torch.randn((hs, ne * fhs)).cuda().half() + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(x, w, topo) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd::SDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(topo, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradX::DSD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + topo = topo.t() + + def benchmark(): + return stk.ops.dsd(topo, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DSD::NN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): 
+ x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(out, w, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::SDD::NT', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + x = x.t() + + def benchmark(): + return stk.ops.dsd(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + + w = w.transpose(1, 2).contiguous() + w = w.transpose(1, 2) + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd:DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + w = w.transpose(1, 2).contiguous() + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradX:DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + out = out.transpose(1, 2) + + def benchmark(): + return torch.bmm(out, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradW:DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) 
+ + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + w = torch.transpose(w, 1, 2) + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + x = torch.transpose(x, 1, 2) + + def benchmark(): + return torch.bmm(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..f272a7768dc6468e81bd2cd25294ca16a6826c08 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_gather.py @@ -0,0 +1,55 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_gather kernel. +class PaddedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins, padded_bins) + ctx.top_k = top_k + return kernels.padded_gather( + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins, padded_bins = ctx.saved_tensors + out = kernels.padded_scatter( + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) + return out, None, None, None, None, None + + +padded_gather = PaddedGatherOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff81dd9456a04258346266e9e85871bda56c65b --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter.py @@ -0,0 +1,98 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_scatter kernel. 
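+# padded_scatter is the inverse of padded_gather: it un-permutes expert outputs
+# back to token order, applying the optional per-token expert weights; its
+# backward reuses padded_gather (for dgrad) and padded_scatter_wgrad (for wgrad).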
+class PaddedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward( + indices, + bin_ids, + weights, + bins, + padded_bins, + *maybe_x, + ) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.padded_gather( + grad, + indices, + bin_ids, + weights, + bins, + padded_bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.padded_scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + padded_bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None, None + + +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): + return PaddedScatterOp.apply( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..81dde4e4b4c0f550c6027390e8de285fa4d842f7 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -0,0 +1,66 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PADDED_SCATTER_BENCHMARK = ( + # dMoE-Medium, 8-way EMP. + (1024 * 16, 1024, 8, 4), + # dMoE-Medium, post-all-to-all. + (1024 * 16 * 4, 1024, 8, 1), +) + + +class PaddedScatterTest(parameterized.TestCase): + + @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) + def testPaddedScatter(self, sl, hs, ne, top_k): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + # Sample weights for the scatter reduce. + weights = torch.rand((sl * top_k,)).cuda().half() + + # Gather the data to prepare for backwards. 
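+ # padded_gather lays the tokens out in padded, per-expert order so the
+ # padded_scatter call benchmarked below measures only the weighted un-permute.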
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + def benchmark(): + return ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + time, std = benchmark_util.benchmark_function(benchmark) + benchmark_util.log_benchmark( + 'Padded Scatter', + { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + 'top_k': top_k, + }, + time, + std, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..837f07e2858d5c65c14e3f262617e7d2985ce901 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -0,0 +1,149 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PERMUTE_TESTS = ( + (16384, 768, 2), + (16384, 768, 4), + (16384, 768, 8), + (16384, 768, 16), + (16384, 768, 32), + (16384, 768, 64), + (16384, 768, 128), + (16384 * 8, 768, 2), + (16384 * 8, 768, 4), + (16384 * 8, 768, 8), + (16384 * 8, 768, 16), + (16384 * 8, 768, 32), + (16384 * 8, 768, 64), + (16384 * 8, 768, 128), +) + + +class PermuteBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedGather(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.binned_gather(x, indices, bins, ec) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedScatter(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.binned_gather(x, indices, bins, ec) + + def benchmark(): + return ops.binned_scatter(x, indices, bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedGather(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. 
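+        # The routing metadata follows the usual pipeline: sort the expert
+        # assignments, histogram tokens per expert, round each count up to
+        # the 128-row block size and take inclusive cumsums to get the bin
+        # and padded-bin boundaries.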
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedScatter(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testCopy(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + # ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + y = x.clone() + + def benchmark(): + return y.copy_(x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/repeat.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/repeat.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def repeat(x: torch.Tensor, tiling: torch.Size): + if all((t == 1 for t in tiling)): + return x + return x.repeat(*tiling) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..e570e22093339d1b71dc830edf656897c67f40e7 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/replicate.py @@ -0,0 +1,36 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. 
+try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for replicate kernel. +class ReplicateOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): + ctx.save_for_backward(bins) + out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) + ops.replicate_forward(x, bins, out) + return out + + @staticmethod + def backward(ctx: Any, grad: torch.Tensor): + bins, = ctx.saved_tensors + out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) + ops.replicate_backward(grad, bins, out) + return out, None, None + + +replicate = ReplicateOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/round_up.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/round_up.py new file mode 100644 index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/round_up.py @@ -0,0 +1,14 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def round_up(x: torch.Tensor, value: int): + assert isinstance(value, int) + assert x.dtype == torch.int32 + + # TODO(tgale): If this becomes and issue + # do this in a custom kernel. We only expect + # to use this on arrays of less than 1k elements. + return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..a5aaafc402a651a47595c25dbf71e382903e5022 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/scatter.py @@ -0,0 +1,72 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for scatter kernel. 
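+# scatter is the unpadded counterpart of padded_scatter: at the kernel level
+# (see megablocks.backend.kernels) the same copy is performed with `bins`
+# standing in for `padded_bins`, so no padding rows are involved.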
+class ScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins = saved_tensors[:4] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.gather( + grad, + indices, + bin_ids, + weights, + bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None + + +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +) -> Optional[torch.Tensor]: + return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort.py new file mode 100644 index 0000000000000000000000000000000000000000..a64a4553323d8de71a1d5f1821bac50249c68197 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional, Tuple + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + +_BITS_FOR_DTYPE = { + torch.int16: 16, + torch.int32: 32, + torch.int64: 64, +} + + +# Autograd wrapper for sort kernel. +# NOTE: Does not support gradients. 
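+# When `end_bit` is None it is taken from the input dtype (e.g. 32 bits for
+# int32). The op returns both the sorted values and the argsort permutation,
+# e.g. `bin_ids, indices = sort(top_experts)` in the routing code.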
+class SortOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + if end_bit is None: + end_bit = _BITS_FOR_DTYPE[x.dtype] + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + ops.sort(x, end_bit, x_out, iota_out) + return (x_out, iota_out) + + +sort = SortOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -0,0 +1,85 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_SORT_TESTS = ( + (16384, torch.int32, None), + (16384, torch.int32, 2), + (16384, torch.int32, 128), +) + +_BASELINE_SORT_TESTS = ((16384,),) + + +def numpy_dtype(dtype): + types = { + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + } + return types[dtype] + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. + fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class SortBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_SORT_TESTS) + def testSort(self, n, dtype, max_val): + if max_val is None: + max_val = np.iinfo(numpy_dtype(dtype)).max + end_bit = int(np.ceil(np.log2(max_val))) + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_BASELINE_SORT_TESTS) + def testTorchSort(self, n): + x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) + arguments = { + 'n': n, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sum.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sum.py new file mode 100644 index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sum.py @@ -0,0 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import torch + + +def sum(x: torch.Tensor, dim: int = 0): + if x.shape[dim] == 1: + return x.squeeze(dim=dim) + return x.sum(dim=dim) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/topology.py new file mode 100644 index 
0000000000000000000000000000000000000000..b4a693a82a36c052e51afb4a02077e7849a4abf1 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/topology.py @@ -0,0 +1,45 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for topology kernel. +# NOTE: Does not support gradients. +class TopologyOp(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, + ): + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + ops.indices( + padded_bins, + block_size, + output_block_rows, + output_block_columns, + out, + ) + return out + + +topology = TopologyOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d8a2495c4baab89e5f90d84fa92f406a29ea10 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py @@ -0,0 +1,195 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from ._ops import ops + +from megablocks.layers.arguments import Arguments +from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE +from megablocks.layers.glu import SparseGLU +from megablocks.layers.mlp import MLP, SparseMLP +from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +# This section contains the direct kernel exports (not inlcuded in the original code) +def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute exclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.exclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute inclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.inclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: + """ + Compute histogram of input tensor values. + + Args: + x: Input tensor + num_bins: Number of histogram bins + + Returns: + Histogram tensor with counts for each bin + """ + return ops.histogram(x, num_bins) + + +def indices( + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, +) -> torch.Tensor: + """ + Construct indices from padded bins for sparse operations. 
+ + Args: + padded_bins: Tensor containing bin boundaries + block_size: Size of each block + output_block_rows: Number of rows in output blocks + output_block_columns: Number of columns in output blocks + + Returns: + Tensor containing constructed indices + """ + return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) + + +def replicate_forward( + x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Forward pass of replicate operation - replicate values according to bin sizes. + + Args: + x: Input tensor with values to replicate + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_forward(x, bins, out) + + +def replicate_backward( + grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Backward pass of replicate operation - reduce gradients back to bins. + + Args: + grad: Gradient tensor to reduce + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_backward(grad, bins, out) + + +def sort( + x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor +) -> torch.Tensor: + """ + Radix sort with index tracking. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + x_out: Output tensor for sorted values + iota_out: Output tensor for sorted indices + + Returns: + The sorted values tensor + """ + return ops.sort(x, end_bit, x_out, iota_out) + + +# Convenience functions for common use cases +def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: + """ + Compute cumulative sum with automatic output allocation. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum (default: last dimension) + exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum + + Returns: + New tensor containing the cumulative sum + """ + out = torch.empty_like(x) + if exclusive: + return exclusive_cumsum(x, dim, out) + else: + return inclusive_cumsum(x, dim, out) + + +def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sort tensor and return both sorted values and indices. 
+ + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + + Returns: + Tuple of (sorted_values, sorted_indices) + """ + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + sort(x, end_bit, x_out, iota_out) + return x_out, iota_out + + +# Export public API +__all__ = [ + # Direct kernel exports + "exclusive_cumsum", + "inclusive_cumsum", + "histogram", + "indices", + "replicate_forward", + "replicate_backward", + "sort", + "cumsum", + "argsort", + # Original exports + "Arguments", + "ParallelDroplessMLP", + "dMoE", + "SparseGLU", + "MLP", + "SparseMLP", + "MoE", + "ParallelMLP", + "get_load_balancing_loss", +] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..a9116710809dc76b278bf71b0c78b35fa9c00c5f --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1d87aabde7c05e3b4d3fd9623b89db0f9b7caa444a2c8f590fc8227211cd2d3 +size 10456616 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd42432c910375fb1713b3bf29145e48b556498 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _megablocks_359242d +ops = torch.ops._megablocks_359242d + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_megablocks_359242d::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_version.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_version.py @@ -0,0 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +"""The MegaBlocks Version.""" + +__version__ = '0.11.0.dev0' diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..b584ceede926ca30abef2dec581cb3ff329e8e16 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py @@ -0,0 +1,543 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') + + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') + + +def assert_equal(a, b): + if a != b: + raise ValueError(f'Expected 
dimensions to be equal but got {a} and {b}.',) + + +# a: (tokens, hidden_size), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy( + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Our index into array 'a'. + index_a = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'b'. + index_b = offset_in_bin + if bin_idx > 0: + index_b += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: Because of the padding, the output size is dynamic. + # We load the final padded bin bound to get the output rows. + output_rows = padded_bins[-1].cpu().item() + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def gather(x, indices, bin_ids, weights, bins, top_k): + # Validate the input shapes. 
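+    # NOTE: gather is the unpadded analogue of padded_gather; `bins` is
+    # passed below for both the bin and padded-bin boundaries, so rows are
+    # packed back-to-back with no padding.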
+ assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: There is no padding so the output rows equals the + # input rows multiplied by top_k. + output_rows = x.shape[0] * top_k + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + out, + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + + +def scatter(x, indices, bin_ids, weights, bins, top_k): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +# x: (tokens, top_k, hidden_size), real +# grad: (tokens, hidden_size), real. +# wgrad: (tokens, top_k), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy_wgrad( + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Our index into 'tokens * top_k'. + index_out = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'x'. + index_x = offset_in_bin + if bin_idx > 0: + index_x += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. 
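+    # The loop below reduces the elementwise product of one expert-output row
+    # of `x` with the matching row of `grad` into a single scalar: the
+    # gradient of the routing weight for this (token, expert) pair.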
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + _padded_copy_wgrad[(indices.shape[0],)]( + x, + grad, + out, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + TOP_K=top_k, + ) + return out + + +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy( + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_b = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_a = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + # + # NOTE: We need to zero the output in both directions. 
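+    # Zero-initialization matters because capacity slots with no token
+    # assigned (and tokens dropped once an expert is full) are never written
+    # by this kernel.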
+ iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + num_experts = bins.shape[0] + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + + _binned_copy[(num_experts, expert_capacity)]( + x, + out, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def binned_scatter(x, indices, weights, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( + out, + x, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=hidden_size, + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy_wgrad( + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_x = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_out = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. 
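+    # As in _padded_copy_wgrad, the loop below accumulates x[row] * grad[row]
+    # and reduces it to the scalar routing-weight gradient for this
+    # (token, expert) slot.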
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def binned_scatter_wgrad(x, grad, indices, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) + _binned_copy_wgrad[(num_experts, expert_capacity)]( + x, + grad, + out, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS=hidden_size, + TOP_K=top_k, + ) + return out diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5217959caf74527e3bf7f80db6f93be21c016963 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py @@ -0,0 +1,23 @@ +from megablocks_moe.megablocks import ( + MoE, + dMoE, + get_load_balancing_loss, + ParallelMLP, + ParallelDroplessMLP, + SparseMLP, + MLP, + SparseGLU, + Arguments, +) + +__all__ = [ + "MoE", + "dMoE", + "get_load_balancing_loss", + "ParallelMLP", + "ParallelDroplessMLP", + "SparseMLP", + "MLP", + "SparseGLU", + "Arguments", +] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py new file mode 100644 index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + + +def log_benchmark(name, arguments, time, std): + print('=' * 60) + print(f'{name} Benchmark') + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) + print('=' * 60) + + +def benchmark_function(fn, iterations=100, warmup=10): + # Warmup iterations. 
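+    # CUDA events time the GPU work itself; the warmup runs absorb one-time
+    # costs such as kernel compilation and autotuning so they do not skew the
+    # reported mean.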
+ for _ in range(warmup): + fn() + + times = [] + for i in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + fn() + end.record() + + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + return np.mean(times), np.std(times) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6d3f977f360fd0ad5800c3b5da9ce57be794b9b8 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import warnings + +_grouped_gemm_is_available: bool = False +try: + import grouped_gemm + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn('Grouped GEMM not available.') + + +def grouped_gemm_is_available(): + return _grouped_gemm_is_available + + +def assert_grouped_gemm_is_available(): + msg = ( + 'Grouped GEMM not available. Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg + + +backend = grouped_gemm.backend if grouped_gemm_is_available() else None +ops = grouped_gemm.ops if grouped_gemm_is_available() else None diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..849b023b1a0f765fb1e26addf4670dc6db785a52 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# from megablocks.layers.dmoe import dMoE +from megablocks.layers.moe import MoE + +__all__ = [ + 'MoE', + # 'dMoE', +] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/activation_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..a31770ba179fed06abf2da10102ccaeed1d3ee4e --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/activation_fn.py @@ -0,0 +1,33 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Union + +import torch +from stk import Matrix + + +def act_fn( + x: Matrix, + function: Callable, + return_grad_fn: bool = False, + **kwargs, +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) + with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): + if return_grad_fn: + x.data.requires_grad = True + out = function(x.data, **kwargs) + y = Matrix( + x.size(), + out, + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + if return_grad_fn: + return y, out.backward + return y diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/all_to_all.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/all_to_all.py @@ -0,0 +1,54 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + + +class 
AllToAllOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) + + ctx.input_shape = x.shape + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + ctx.group = group + handle = dist.all_to_all_single( + out, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + return out, handle + + @staticmethod + def backward(ctx, grad, _): + if ctx.needs_input_grad[0]: + out = torch.empty( + ctx.input_shape, + device=grad.device, + dtype=grad.dtype, + ) + dist.all_to_all_single( + out, + grad, + output_split_sizes=ctx.input_split_sizes, + input_split_sizes=ctx.output_split_sizes, + group=ctx.group, + ) + return out, None, None, None, None + return None, None, None, None, None + + +def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): + return AllToAllOp.apply( + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..3962c771c90012535aab058443f89c541a2e9236 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py @@ -0,0 +1,100 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from functools import partial +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +import megablocks.grouped_gemm_util as grouped_gemm + +# Type annotation for in-place Tensor initialization function. +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] + +_ALLOWED_BITWIDTHS = (-1, 4, 8) + +DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') + + +@dataclasses.dataclass +class Arguments: + # Model arguments. + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + num_layers: int = 1 + bias: bool = True + return_bias: bool = True + activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN + + # MoE arguments. + moe_num_experts: int = 1 + moe_top_k: int = 1 + moe_capacity_factor: int = 1 + moe_normalize_expert_weights: Optional[Union[int, float]] = None + moe_loss_weight: float = 0.1 + moe_jitter_eps: Optional[float] = None + moe_lbl_in_fp32: bool = False + + # Parallelism arguments. + moe_expert_model_parallelism: bool = False + expert_parallel_group: Optional[dist.ProcessGroup] = None + pipeline_model_parallel_size: int = 1 + num_layers_per_virtual_pipeline_stage: Optional[int] = None + + # Compute arguments. + memory_optimized_mlp: bool = False + mlp_type: str = 'mlp' + mlp_impl: str = 'sparse' + + # Initialization arguments. + fp16: bool = True + bf16: bool = False + device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) + init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) + output_layer_init_method: InitFn = init_method + + # Benchmarking arguments. 
+ uniform_expert_assignment: bool = False + + # shared expert arguments + shared_expert: bool = False # enable using shared expert + fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) + fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers + remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored + shared_expert_hidden_size: Optional[ + int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size + shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) + + # Router Z-loss arguments + moe_zloss_weight: float = 0 # 1e-3 is a reasonable value + moe_zloss_in_fp32: bool = False + + def __post_init__(self): + # Sparse MLP is not supported with triton >=3.2.0 + # TODO: Remove this once sparse is supported with triton >=3.2.0 + if self.__getattribute__('mlp_impl') == 'sparse': + try: + import triton + if triton.__version__ >= '3.2.0': + raise ValueError( + 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.', + ) + except ImportError: + raise ImportError('Triton is required for sparse MLP implementation') + + if self.__getattribute__('mlp_impl') == 'grouped': + grouped_gemm.assert_grouped_gemm_is_available() + + if self.shared_expert_hidden_size is None: + self.shared_expert_hidden_size = self.ffn_hidden_size + + +def from_megatron(megatron_args: Any): + args = Arguments() + for field in dataclasses.fields(args): + if hasattr(megatron_args, field.name): + setattr(args, field.name, getattr(megatron_args, field.name)) + return args diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/common.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/common.py new file mode 100644 index 0000000000000000000000000000000000000000..ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/common.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from megablocks.layers.arguments import Arguments + + +def dtype(args: Arguments): + if args.fp16: + return torch.float16 + elif args.bf16: + return torch.bfloat16 + return None + + +def cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == 'cuda': + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == 'cpu': + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d765bd04387a29bddd2789fae04821635b555e82 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -0,0 +1,42 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +MlpType = Union[mlp.SparseMLP, glu.SparseGLU] + +_REGISTRY = { + 'mlp': { + 'grouped': mlp.GroupedMLP, + 'sparse': mlp.SparseMLP, + }, + 'glu': { + 'grouped': glu.GroupedGLU, + 'sparse': glu.SparseGLU, + }, +} + + +def get(args: Arguments) -> MlpType: + 
"""Returns an MLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + MLP instance. This only contains MLPs for use in dMoEs + (ie. only for the dropless versions of MoEs). + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated MLP constructed using the input args. + """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + if args.mlp_impl not in _REGISTRY[args.mlp_type]: + raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) + + return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..205727ff4d63f9e8dc9648acaac99a97f3394d6f --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py @@ -0,0 +1,327 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import stk.ops +import torch +from stk import Matrix + +import megablocks.ops as ops +# from megablocks.ops import ops +from megablocks.layers import common, dmlp_registry, moe, mpu +from megablocks.layers.arguments import Arguments + + +def promote_scalar(x): + return x.view(1) if not len(x.size()) else x + + +class ParallelDroplessMLP(moe.ParallelMLP): + + def __init__(self, args: Arguments): + super(ParallelDroplessMLP, self).__init__(args) + self.hidden_size = args.hidden_size + self.ffn_hidden_size = mpu.features_per_rank(args) + self.blocking = 128 + self.mlp = dmlp_registry.get(args) + + # Calculate the number of bits needed to represent the column indices + # in the intermediate sparse matrix. + max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) + self.transpose_sort_end_bit = max( + int(np.ceil(np.log2(max_column_index))), + 1, + ) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + # + # NOTE: Our sort operation uses the same width indices as the input values. + # To avoid overflow when we have large activation matrices we cast to + # 32-bit before sorting. + _, gather_indices = ops.sort( + column_indices.int(), + self.transpose_sort_end_bit, + ) + + # There are a constant number of blocks in every row of the sparse matrix. + # A blocks offset is: + # + # row_index * blocks_per_row + column_index % blocks_per_row + # + # Once we have the block offsets ordered for transposition we can divide + # by blocks_per_row to get the transposed column indices. + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + if nnz_per_column.dim() == 0: + # This addresses an edge case when ffn_hidden_size is equal to self.blocking. 
+ nnz_per_column = nnz_per_column.unsqueeze(0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + if self.ffn_hidden_size % self.blocking != 0: + raise ValueError( + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + + f'the block size {self.blocking}. Please update your configuration.', + ) + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // self.blocking + blocks_per_row = self.ffn_hidden_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology( + padded_bins, + self.blocking, + block_rows, + blocks_per_row, + ) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=common.dtype(self.args), + device='meta', + ) + shape = ( + padded_tokens, + self.ffn_hidden_size * mpu.experts_per_rank(self.args), + ) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, + row_indices, + column_indices, + offsets, + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_experts = top_experts.int() + bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Caculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = promote_scalar(bins) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def sparse_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather( + x, + indices, + bin_ids, + bins, + padded_bins, + self.top_k, + ) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. 
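+        # padded_scatter drops the padding rows, restores the original token
+        # order, applies the router weights and sums the top-k expert outputs
+        # for each token.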
+ x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.top_k, + ) + return x, tokens_per_expert + + # For use in the base-class parallel_forward_once. + def sparse_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + return ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + top_k, + ) + + def grouped_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + out = self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + -1, # unused + self.args.moe_top_k, + ) + return out, tokens_per_expert + + def grouped_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Perform the expert computation. + x = self.mlp(x, tokens_per_expert) + + # Un-route the data for the MoE output. 
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + def forward_once(self, x, expert_weights, top_experts): + if self.args.mlp_impl == 'sparse': + return self.sparse_forward_once(x, expert_weights, top_experts) + else: + return self.grouped_forward_once(x, expert_weights, top_experts) + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ): + if self.args.mlp_impl == 'sparse': + return self.sparse_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + else: + return self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + + +class dMoE(moe.MoE): + + def _init_experts_mlp(self, args: Arguments): + return ParallelDroplessMLP(args) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/gelu.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..40b601d4a8eb59e80e1090f31f4172b8f7fb7549 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/gelu.py @@ -0,0 +1,43 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk +import torch +import torch.nn.functional as F + + +@torch.jit.script +def _gelu_backward_inplace(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) + return g.mul_(ff) + + +def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): + # NOTE: The two sparse matrices must have the same topology. + if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): + return stk.Matrix( + x.size(), + _gelu_backward_inplace(grad.data, x.data), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + return _gelu_backward_inplace(grad, x) + + +def gelu(x: stk.Matrix): + assert isinstance(x, stk.Matrix) + return stk.Matrix( + x.size(), + F.gelu(x.data, approximate='tanh'), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe0c915c307e7f7cade3ea3ff679399635fcd81 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py @@ -0,0 +1,223 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk.ops +import torch + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import Arguments +from megablocks.layers.mlp import ( + SharedMLP, + SparseMLP, + create_dmoe_expert_weights, + resolve_dtensor, +) + + +class SparseGLU(SparseMLP): + + def __init__(self, args: Arguments): + super().__init__(args) + self.v1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + with torch.no_grad(): + self.v1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + + mpu.set_expert_model_parallel_attributes( + 
self.v1, + self._should_set_parallelism_attribute, + ) + + def forward(self, x, topo): + if self.args.memory_optimized_mlp: + raise NotImplementedError( + 'Memory optimized implementation not yet supported with GLU with sparse kernels.', + ) + + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Compute the GLU. + x1 = stk.ops.sdd(x, w1.t(), topo) + x2 = stk.ops.sdd(x, v1.t(), topo) + + activation_fn_out = act_fn(x1, self.args.activation_fn) + x1 = stk.ops.mul(activation_fn_out, x2) + + return stk.ops.dsd(x1, w2) + + +class MemoryOptimizedGroupedGLU(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + v1 = v1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) + + # GeLU. + activation_fn_out = activation_fn(sdd_out) * v1_out + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, v1, w2 = saved_tensors[:3] + batch_sizes = saved_tensors[3] + x = saved_tensors[4] + sdd_out, v1_out = saved_tensors[5:7] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + v1_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * v1_out + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + dv1_out = v1_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dv1. 
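+ # Same pattern as dw1: with trans_a=True the grouped GEMM computes, for each
+ # expert e, dv1_e = dv1_out_e.T @ x_e over that expert's rows of the batch.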
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + dx = ddsd_out + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) + dx += gg.backend.gmm(dv1_out, v1, batch_sizes) + return dx, dw1, dv1, dw2, None, None + + +memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply + + +class GroupedGLU(SparseGLU): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, v1, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.v1), + self.scale_grad(self.w2), + ) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Re-shape the weights for the grouped GEMMs. + ne = mpu.experts_per_rank(self.args) + w1 = w1.view(ne, -1, self.args.hidden_size) + v1 = v1.view(ne, -1, self.args.hidden_size) + w2 = w2.view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_glu( + x, + w1, + v1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) + x1 = self.args.activation_fn(x1) * x2 + return gg.ops.gmm(x1, w2, batch_sizes) + + +class SharedGLU(SharedMLP): + """GPU for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class + """ + + def __init__(self, args: Arguments): + super().__init__(args) + self.gate_proj = args.fc_cls( + args.hidden_size, + self.args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4acbd94f212ce906ace9de78dd3ffc6afa03f97e --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py @@ -0,0 +1,102 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import gc + +import torch +import torch.distributed as dist + +from megablocks.layers import arguments, dmoe + +_TESTS = ((8, 2048, 4096, 4096, 32, 4),) + + +def get_tensors(): + ptrs = set() + out = [] + for obj in gc.get_objects(): + if torch.is_tensor(obj): + if not obj.is_contiguous() or obj.data_ptr() in ptrs: + continue + out.append(obj) + ptrs.add(obj.data_ptr()) + return out + + +def test_memory( + group, + batch_size, + sequence_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, +): + args = arguments.Arguments( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=num_experts, + moe_top_k=top_k, + moe_expert_model_parallelism=True, + expert_parallel_group=group, + fp16=False, + bf16=True, + device=torch.cuda.current_device(), + ) + layer = dmoe.dMoE(args).cuda() + + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16).requires_grad_(True) + torch.cuda.empty_cache() + + # Run forward + backward. + # with torch.autograd.detect_anomaly(): + out, _ = layer(x) + out.mean().backward() + + # Report peak memory. 
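+ # torch.cuda.max_memory_allocated/max_memory_reserved return peak values in
+ # bytes since startup (or the last reset_peak_memory_stats call). The
+ # divisions below use 1e6, so the printed figures are decimal MB even though
+ # the labels say MiB.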
+ mem = torch.cuda.max_memory_allocated() + print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) + print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) + + # Calculate weight and gradient memory usage. + weight_memory = 2 * ( + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() + ) + + def grad_numel(x): + if x.grad is not None: + return x.grad.numel() + return 0 + + grad_memory = 2 * ( + grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + ) + weight_memory += grad_memory + + print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) + print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) + + # Manually calculate GPU memory usage from the garbage + # collector. + gc.collect() + total = 0 + tensors = get_tensors() + tensors = sorted(tensors, key=lambda x: -x.numel()) + for i, t in enumerate(tensors): + total += t.numel() + print(f'{i}: {t.shape}, {t.numel() * 2}') + del tensors + + print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _TESTS: + test_memory(group, *args) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6f4d82441e3a9c185db5bfdf686d53790dde26 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py @@ -0,0 +1,574 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import stk +import stk.backend.triton_kernels +import stk.ops +import torch +from packaging import version + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, gelu, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + + +class ScaleGradient(torch.autograd.Function): + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx: Any, x: torch.Tensor, scale: float): + ctx.scale = scale + return x + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: torch.Tensor, grad: torch.Tensor): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +def resolve_dtensor(weight: torch.Tensor): + if version.parse(torch.__version__) >= version.parse('2.0.0'): + from torch.distributed._tensor import DTensor + if isinstance(weight, DTensor): + return weight.to_local() + return weight + + +def create_moe_expert_weights( + args: Arguments, + num_experts: int, + ffn_hidden_size: int, + hidden_size: int, + init_method: InitFn, +): + # Create the entire weight matrix such that the sampled weights will + # not vary between data parallelism and expert model parallelism for + # the same random seed. + master_weights = torch.empty( + num_experts, + ffn_hidden_size, + hidden_size, + device=args.device, + dtype=common.dtype(args), + ) + init_method(master_weights) + + if not args.moe_expert_model_parallelism: + return master_weights + + # Calculate the amount of sharding in each dimension. 
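+ # Illustrative example (hypothetical sizes): with an expert-parallel world
+ # size of 8 and moe_num_experts=4, expert_sharding_degree = min(8, 4) = 4 and
+ # hidden_sharding_degree = 8 // 4 = 2, so each rank owns one expert and half
+ # of that expert's ffn_hidden_size rows.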
+ expert_sharding_degree = mpu.expert_sharding_degree(args) + hidden_sharding_degree = mpu.hidden_sharding_degree(args) + + # Calculate the experts per rank. + # + # NOTE: We assign ranks to be expert parallel before going + # tensor parallel. + rank = mpu.get_expert_parallel_rank(args) + expert_rank = rank % expert_sharding_degree + num_experts_per_rank = num_experts // expert_sharding_degree + start_expert = expert_rank * num_experts_per_rank + end_expert = (expert_rank + 1) * num_experts_per_rank + + # Calculate the rows per rank. + row_rank = rank // expert_sharding_degree + num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree + start_row = row_rank * num_rows_per_rank + end_row = (row_rank + 1) * num_rows_per_rank + + # Slice the weight matrix to get the chunk for this rank. + with torch.no_grad(): + weights = master_weights[start_expert:end_expert, start_row:end_row] + return weights + + +class MLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) + experts_per_rank = mpu.experts_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + args.hidden_size, + mpu.features_per_rank(args), + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + mpu.features_per_rank(args), + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + mpu.set_expert_model_parallel_attributes( + self.w1, + args.moe_expert_model_parallelism, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + args.moe_expert_model_parallelism, + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. 
+ with torch.no_grad(): + w1 = create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) + self.w1.copy_(w1.transpose(1, 2).contiguous()) + self.w2.copy_( + create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + x = torch.bmm(x, w1) + x = self.args.activation_fn(x) + return torch.bmm(x, w2) + + +def create_dmoe_expert_weights( + args: Arguments, + num_experts: int, + rows: int, + columns: int, + init_method: InitFn, +): + weights = create_moe_expert_weights( + args, + num_experts, + rows, + columns, + init_method, + ) + return weights.view([-1, columns]) + + +class MemoryOptimizedMLP(torch.autograd.Function): + """Sparse MLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, topo, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + topo_tensors = ( + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) + + # Layer 0: x @ w1.t(). + sdd_out = stk.ops.sdd(x, w1.t(), topo) + + # GeLU. + activation_fn_out = act_fn(sdd_out, activation_fn) + + # Layer 1: x @ w2. + dsd_out = stk.ops.dsd(activation_fn_out, w2) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.shape = topo.shape + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.data.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + topo_tensors = saved_tensors[2:8] + x = saved_tensors[8] + sdd_out_data = saved_tensors[9] + + # rematerialize activation function output + activation_fn = ctx.activation_fn + sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) + activation_fn_out, activation_grad_fn = act_fn( + sdd_out, + activation_fn, + return_grad_fn=True, + ) + + # Compute dw2 with recomputed activation_fn output. + dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. 
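+ # The triton sdd call below writes its result directly into the existing
+ # .data buffer of dactivation_fn_out (passed as an explicit output), so no
+ # new intermediate of that size is allocated.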
+ dactivation_fn_out = activation_fn_out + stk.backend.triton_kernels.sdd( + ddsd_out, + w2.t(), + dactivation_fn_out.shape, + dactivation_fn_out.data, + dactivation_fn_out.offsets, + dactivation_fn_out.row_indices, + dactivation_fn_out.column_indices, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out.data) + dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) + + # Compute dw1. + dw1 = stk.ops.dsd(dsdd_out.t(), x) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + stk.backend.triton_kernels.dsd( + dsdd_out.shape, + dsdd_out.data, + dsdd_out.offsets, + dsdd_out.row_indices, + dsdd_out.column_indices, + dsdd_out.offsets_t, + dsdd_out.column_indices_t, + dsdd_out.block_offsets_t, + False, + w1, + ddsd_out, + ) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_mlp = MemoryOptimizedMLP.apply + + +class SparseMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. + with torch.no_grad(): + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism + mpu.set_expert_model_parallel_attributes( + self.w1, + self._should_set_parallelism_attribute, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + self._should_set_parallelism_attribute, + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x, topo): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + if self.args.memory_optimized_mlp: + return memory_optimized_mlp( + x, + w1, + w2, + topo, + self.args.activation_fn, + ) + + # Compute the MLP. 
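+ # In stk, sdd multiplies two dense operands but only materializes the output
+ # blocks named by `topo` (sparse output), while dsd multiplies the resulting
+ # block-sparse matrix by a dense one to give a dense output.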
+ x = stk.ops.sdd(x, w1.t(), topo) + activation_fn_out = act_fn(x, self.args.activation_fn) + return stk.ops.dsd(activation_fn_out, w2) + + +class MemoryOptimizedGroupedMLP(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + + # activation_fn + activation_fn_out = activation_fn(sdd_out) + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: Any, ddsd_out: torch.Tensor): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + batch_sizes = saved_tensors[2] + x = saved_tensors[3] + sdd_out = saved_tensors[4] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply + + +class GroupedMLP(SparseMLP): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) + + # Re-shape the weights for the grouped GEMMs. 
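+ # w1/w2 are stored flat as [experts_per_rank * features_per_rank, hidden_size];
+ # the views below expose the per-expert leading dimension, i.e.
+ # [ne, features_per_rank, hidden_size], which the grouped GEMM expects.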
+ ne = mpu.experts_per_rank(self.args) + w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) + w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_mlp( + x, + w1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x = self.args.activation_fn(x) + return gg.ops.gmm(x, w2, batch_sizes) + + +class SharedMLP(torch.nn.Module): + """MLP for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class + """ + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self.fc_kwargs: dict[str, Any] = { + 'bias': args.bias, + 'device': args.device, + } + self.fc_kwargs.update(args.fc_kwargs) + + self.up_proj = args.fc_cls( + args.hidden_size, + args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + self.act = args.activation_fn + self.down_proj = args.fc_cls( + args.shared_expert_hidden_size, + args.hidden_size, + **self.fc_kwargs, + ) + self.down_proj._is_residual = True # a flag for llm-foundry init + + def add_experts_sharedexpert( + self, + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + ) -> torch.Tensor: + # Helper function to add expert output to shared expert output + # with optional weighted sum. + if self.args.shared_expert_weighted_sum: + # enable using weighted sum for shared expert output + # wieghted by number of experts used + t_experts = self.args.moe_top_k + 1 + sh_mlp_out = shared_expert_out / t_experts + return sh_mlp_out.add( + expert_out, + alpha=(self.args.moe_top_k / t_experts), + ) + + return shared_expert_out + expert_out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py @@ -0,0 +1,475 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist + +import megablocks.ops as ops +from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +from megablocks.layers.all_to_all import all_to_all +from megablocks.layers.arguments import Arguments + +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss(args: Arguments): + if args.moe_loss_weight == 0: + return 0.0 + + # tokens_per_expert[i].shape = (num_experts) + # expert_scores[i].shape = (tokens, num_experts) + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} token_per_experts ' 
+ f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} expert_scores ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) + + tokens = expert_scores[0].shape[0] + assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = (args.moe_num_experts * args.moe_loss_weight) + scale_denominator = (args.num_layers * tokens * args.moe_top_k) + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +# NOTE: This class defines MoE expert computation, including expert model parallel +# communication. When using FSDP on top of MegaBlocks this is the module that should +# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model +# parallel all2all. +class ParallelMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super(ParallelMLP, self).__init__() + self.args = args + + # Calculate the number of experts in total and the number of experts + # owned by this rank. + # world_size = mpu.get_expert_parallel_world_size(args) + self.num_experts = args.moe_num_experts + self.top_k = self.args.moe_top_k + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + # Expert MLP. + self.mlp = mlp.MLP(args) + + self.bias: Optional[torch.Tensor] + if self.args.bias: + # Note that the output bias is not parallelized with expert + # model parallelism. + self.bias = torch.nn.Parameter( + torch.empty( + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + # Select the forward function for the operating mode. 
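+ # forward_once keeps every expert local to this rank; parallel_forward_once
+ # additionally exchanges tokens across the expert-parallel group with
+ # all_to_all before and after the expert computation.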
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) + + def expert_capacity(self, tokens: int) -> int: + world_size = mpu.get_expert_parallel_world_size(self.args) + tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) + return int(self.args.moe_capacity_factor * tokens_per_expert) + + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): + """Calculate the load balancing loss contribution.""" + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == self.num_experts + assert len(tokens_per_expert.size()) == 1 + num_experts, = tokens_per_expert.size() + assert num_experts == self.num_experts + scale = self.num_experts / (tokens * self.top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # + # TODO(tgale): Is it worth doing this conversion to 32-bit + # prior? Could we place the `torch.max` operation to return + # 32-bit expert indices? + top_expert = top_expert.int() + output = ops.sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + # + # TODO(tgale): Does the sorted data produce a more favorable + # data distribution for histogram? Or is the op parallelism + # worth more? + tokens_per_expert = ops.histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None + bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + + return indices, bin_ids, bins, tokens_per_expert + + def permute_and_compute( + self, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, + ): + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output + + # Perform the expert computation. Note that we don't + # use biases for these linear operations. + x = self.mlp(x) + + # Un-route the data for the MoE output. + return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. 
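+ # Illustrative example (hypothetical sizes): with sl * bs = 1024 tokens,
+ # top_k = 2, 8 experts and a single expert-parallel rank,
+ # expert_capacity(1024) = int(moe_capacity_factor * 2 * 1024 * 1 / 8), i.e.
+ # 256 slots per expert at a capacity factor of 1.0; a result of 0 falls back
+ # to max(tokens_per_expert) so no tokens are dropped.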
+ sl, bs, _ = x.size() + expert_capacity = self.expert_capacity(sl * bs) + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = self.permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + self.top_k, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignement. After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + # + # Compute the mapping of local tokens to experts. + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so every device gets the counts. + repeated_tokens_per_expert = ops.repeat( + tokens_per_expert, + (mpu.hidden_sharding_degree(self.args),), + ) + + # Pass token count information to the device on which the + # target expert resides. + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.args.expert_parallel_group, + async_op=True, + ) + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + # + # This view updates the shape of the tensor from [sl, bs, hs] to + # [sl * bs, hs] prior to the permutation. + x = x.view(-1, x.shape[-1]) + output = ops.gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + experts_per_rank = mpu.experts_per_rank(self.args) + + # Reshape to [world_size, num_experts_per_rank]. + world_size = mpu.get_expert_parallel_world_size(self.args) + repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) + + # TODO(tgale): It might be faster to do this on the GPU and + # then communicate the results back to the host. + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. 
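+ # send_counts[i] / recv_counts[i] hold how many token rows this rank sends
+ # to / receives from rank i; they are converted to plain Python lists of
+ # split sizes for the all_to_all below.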
+ send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so devices that share experts + # get all of the tokens assigned to them. + # + # TODO(tgale): Fuse this into the prior, local permutation. + x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, + recv_counts, + send_counts, + self.args.expert_parallel_group, + async_op=True, + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + replicate_bins = ops.inclusive_cumsum( + parallel_tokens_per_expert.flatten(), + 0, + ) + replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * mpu.hidden_sharding_degree(self.args), + dtype=torch.int32, + device=indices.device, + ), + mpu.experts_per_rank(self.args), + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, + tokens_received, + ).flatten() + + # TODO(tgale): The sort_end_bit here can be reduced. + parallel_bin_ids, parallel_indices = ops.sort( + parallel_top_expert, + self.sort_end_bit, + ) + + # Calculate the bins boundaries from the token counts. + parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, + dtype=torch.int, + ) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + tokens, _ = x.size() + expert_capacity = self.expert_capacity(tokens) + if expert_capacity == 0: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + if self.args.mlp_impl == 'grouped': + # GroupedMLP requires counts on CPU. We can use the tensor already + # moved to CPU for the prior all_to_all, which avoids an extra + # device synchronization. + parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( + dim=0, + dtype=torch.int, + ) + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all( + parallel_x, + send_counts, + recv_counts, + self.args.expert_parallel_group, + ) + + # Reduce along the hidden sharding to get the final outputs. + # + # TODO(tgale): Fuse this into the following local permutation. 
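+ # When hidden_sharding_degree > 1, several ranks hold different row slices of
+ # the same experts, so each token comes back with that many partial outputs;
+ # the view + sum below adds them together (with a degree of 1 this reduces
+ # over a singleton dimension).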
+ shape = ( + mpu.hidden_sharding_degree(self.args), + -1, + self.args.hidden_size, + ) + x = ops.sum(x.view(shape), dim=0) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + return x, tokens_per_expert.flatten() + + def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + in_shape = x.size() + + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) + if self.training and self.args.moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, scores)) + x = x.view(in_shape) + if self.bias is not None: + if self.args.return_bias: + return x, self.bias + return x + self.bias + return x + + +class MoE(torch.nn.Module): + + def __init__(self, args: Arguments): + super(MoE, self).__init__() + + # Token router. + self.router = router.LearnedRouter(args) + + # Expert computation helper. + self.experts = self._init_experts_mlp(args) + + self.shared_expert = None + if args.shared_expert: + # SharedExpert computation helper. + self.shared_expert = sharedexpert_registry.get(args) + + def _init_experts_mlp(self, args: Arguments): + return ParallelMLP(args) + + def forward(self, x: torch.Tensor): + # NOTE: If we're going to cast the activations to lower precision + # do it before we permute the tokens to save bandwidth. + x = common.cast_if_autocast_enabled(x) + + # Compute the expert scores and assignments. + scores, expert_weights, top_experts = self.router(x) + + # Compute the experts. + out = self.experts(x, scores, expert_weights, top_experts) + if self.shared_expert is not None: + shared_expert_out = self.shared_expert(x) + out = self.shared_expert.add_experts_sharedexpert( + shared_expert_out, + out, + ) + return out diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py new file mode 100644 index 0000000000000000000000000000000000000000..b23213902f4567d7fdb0158cbcf5406c2b2aa601 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py @@ -0,0 +1,93 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.distributed as dist + +from megablocks.layers.arguments import Arguments + + +class MoeParam(torch.Tensor): + + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool + + +def is_moe_param(tensor: torch.Tensor) -> bool: + return hasattr(tensor, 'expert_model_parallel') + + +def get_expert_parallel_world_size(args: Arguments) -> int: + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) + + +def get_expert_parallel_rank(args: Arguments) -> int: + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) + + +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, 'expert_model_parallel') + setattr(tensor, 'expert_model_parallel', is_parallel) + + +def param_is_expert_model_parallel(param: MoeParam) -> bool: + return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) + + +def copy_expert_model_parallel_attributes( + destination_tensor: torch.Tensor, + source_tensor: torch.Tensor, +): + if hasattr(source_tensor, 'expert_model_parallel'): + setattr( + destination_tensor, + 'expert_model_parallel', + getattr(source_tensor, 
'expert_model_parallel'), + ) + + +def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + for i in range(world_size): + dist.barrier(group) + if i == rank: + print(f'rank = {rank}', *x) + + +# Helpers for expert/tensor sharding. +def expert_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = min(world_size, args.moe_num_experts) + + if (args.moe_num_experts % esd) != 0: + raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) + return esd + + +def hidden_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = expert_sharding_degree(args) + hsd = world_size // esd + + if (args.ffn_hidden_size % hsd) != 0: + raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", + ) + return hsd + + +def experts_per_rank(args: Arguments) -> int: + return args.moe_num_experts // expert_sharding_degree(args) + + +def features_per_rank(args: Arguments) -> int: + return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/router.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/router.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9dcd9e433322482a80d5b95afccee5c12368f8 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/router.py @@ -0,0 +1,114 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch + +from megablocks.layers import common +from megablocks.layers.arguments import Arguments + +_ROUTER_LOGITS = [] + + +def _save_router_logits(logits: torch.Tensor, args: Arguments): + if args.moe_zloss_weight == 0: + return + global _ROUTER_LOGITS + _ROUTER_LOGITS.append(logits) + + +def clear_router_zloss(): + global _ROUTER_LOGITS + _ROUTER_LOGITS.clear() + + +def batched_router_zloss(args: Arguments): + global _ROUTER_LOGITS + + if args.moe_zloss_weight == 0: + import warnings + warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') + return 0 + + logits_per_router = _ROUTER_LOGITS + + if args.moe_zloss_in_fp32: + logits_per_router = [logits.float() for logits in logits_per_router] + + unscaled_zloss_per_router = torch.stack([ + torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router + ]) + + return args.moe_zloss_weight * unscaled_zloss_per_router + + +# NOTE: To enable end-to-end benchmarking without convergence we +# support a flag to force the router to assign tokens uniformly +# across the experts. We do this with a custom autograd operation +# so that PyTorch still executes the full set of router operation. +class _UniformExpertAssignment(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, num_experts: int): + out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) + out = torch.remainder(out, num_experts) + return out.view(x.shape) + + +_uniform_expert_assignment = _UniformExpertAssignment.apply + + +class LearnedRouter(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + + # Learned router parameters. + # + # NOTE: This weight matrix is not parallelized with expert model + # parallelism. 
Each device needs the entire router weight matrix + # so that it can route its batch of data correctly. + self.layer = torch.nn.Linear( + args.hidden_size, + args.moe_num_experts, + bias=False, + dtype=common.dtype(args), + device=args.device, + ) + args.init_method(self.layer.weight) + + def jitter(self, x: torch.Tensor): + low: float = 1.0 - self.args.moe_jitter_eps + high: float = 1.0 + self.args.moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return low + noise * (high - low) + + def _top_k(self, scores: torch.Tensor): + if self.args.moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, self.args.moe_top_k, dim=-1) + + def forward(self, x: torch.Tensor): + if self.training and self.args.moe_jitter_eps is not None: + x = x * self.jitter(x) + + logits = self.layer(x.view(-1, x.shape[-1])) + _save_router_logits(logits, self.args) + scores = logits.softmax(dim=-1) + expert_weights, expert_indices = self._top_k(scores) + if self.args.moe_normalize_expert_weights: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + + expert_indices = ( + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) if self.args.uniform_expert_assignment else expert_indices + ) + return scores, expert_weights, expert_indices diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..0f62db39a2dda5642f6baa9d9b949f7c31cf6d35 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -0,0 +1,30 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +_REGISTRY = { + 'mlp': mlp.SharedMLP, + 'glu': glu.SharedGLU, +} + + +def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: + """Returns an SharedMLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + SharedMLP instance. + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated SharedMLP constructed using the input args. 
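+
+ Example (illustrative, other required Arguments fields elided)::
+
+ args = Arguments(mlp_type='glu', shared_expert=True, ...)
+ shared = sharedexpert_registry.get(args) # -> glu.SharedGLU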
+ """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + return _REGISTRY[args.mlp_type](args) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from megablocks.ops.binned_gather import binned_gather +from megablocks.ops.binned_scatter import binned_scatter +from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum +from megablocks.ops.gather import gather +from megablocks.ops.histogram import histogram +from megablocks.ops.padded_gather import padded_gather +from megablocks.ops.padded_scatter import padded_scatter +from megablocks.ops.repeat import repeat +from megablocks.ops.replicate import replicate +from megablocks.ops.round_up import round_up +from megablocks.ops.scatter import scatter +from megablocks.ops.sort import sort +from megablocks.ops.sum import sum +from megablocks.ops.topology import topology + +__all__ = [ + 'binned_gather', + 'binned_scatter', + 'exclusive_cumsum', + 'inclusive_cumsum', + 'gather', + 'histogram', + 'padded_gather', + 'padded_scatter', + 'repeat', + 'replicate', + 'round_up', + 'scatter', + 'sort', + 'sum', + 'topology', +] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..47b95301dbb35b0c3154d1ff2878d20792ce5cb2 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -0,0 +1,60 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + +from megablocks import benchmark_util +from megablocks.layers.all_to_all import all_to_all + +_ALL_TO_ALL_BENCHMARK = ( + (8, 1024), + (16, 1024), + (32, 1024), + (64, 1024), + (128, 1024), + (256, 1024), + (512, 1024), + (1024, 1024), + (2 * 1024, 1024), + (4 * 1024, 1024), + (8 * 1024, 1024), + (16 * 1024, 1024), + (32 * 1024, 1024), + (64 * 1024, 1024), + (128 * 1024, 1024), + (256 * 1024, 1024), + (512 * 1024, 1024), + (1024 * 1024, 1024), +) + + +def benchmark_all_to_all(group, sl, hs): + world_size = dist.get_world_size(group) + assert (sl % world_size) == 0 + send_recv_sizes = [sl // world_size] * world_size + + x = torch.randn((sl, hs)).cuda().half() + + details = { + 'world_size': world_size, + 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. 
+ } + + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) + + time, std = benchmark_util.benchmark_function(benchmark) + + if dist.get_rank(group) == 0: + benchmark_util.log_benchmark('All-To-All', details, time, std) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _ALL_TO_ALL_BENCHMARK: + benchmark_all_to_all(group, *args) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..89cce1b627137d7f987baca62fd6d82c5c04659a --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py @@ -0,0 +1,37 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_gather kernel. +class BinnedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): + ctx.save_for_backward(indices, bins) + ctx.top_k = top_k + return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + indices, bins = ctx.saved_tensors + out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) + return out, None, None, None, None + + +binned_gather = BinnedGatherOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ce0d6f5a890a58b89e6a501216a9822341323f --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py @@ -0,0 +1,59 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_scatter kernel. +class BinnedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + assert len(x.size()) == 3 + ctx.bin_size = x.size(1) + ctx.top_k = top_k + + # TODO(tgale): Don't save 'x' for backwards if we don't need to + # calculate the gradient w.r.t. 'weights'. 
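+ # 'x' is only needed in backward to form the gradient w.r.t. 'weights'
+ # (the wgrad branch); indices and bins alone suffice for the data gradient.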
+ ctx.save_for_backward(x, indices, weights, bins) + return kernels.binned_scatter(x, indices, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + x, indices, weights, bins = ctx.saved_tensors + out = kernels.binned_gather( + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[2]: + wgrad = kernels.binned_scatter_wgrad( + x, + grad, + indices, + bins, + ctx.top_k, + ) + return out, None, wgrad, None, None + + +binned_scatter = BinnedScatterOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7c5137c900d66a959a76fa681436362a4df906 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py @@ -0,0 +1,52 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + # import megablocks_ops as ops # type: ignore + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrappers for cumsum kernels. +# NOTE: Does not support gradients. +class ExclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int): + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.exclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.exclusive_cumsum(x, dim, out) + return out + + +exclusive_cumsum = ExclusiveCumsumOp.apply + + +class InclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.inclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.inclusive_cumsum(x, dim, out) + return out + + +inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py new file mode 100644 index 0000000000000000000000000000000000000000..41b09a1233e8a996ff1062579cd0810d095ad1e6 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for gather kernel. 
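+# Roughly: gather permutes token rows into expert-contiguous order, duplicating
+# each token top_k times (one copy per selected expert), so the result has
+# x.shape[0] * top_k rows; backward routes gradients back through the scatter kernel.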
+class GatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins) + ctx.top_k = top_k + return kernels.gather(x, indices, bin_ids, None, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins = ctx.saved_tensors + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) + return out, None, None, None, None, None + + +gather = GatherOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py new file mode 100644 index 0000000000000000000000000000000000000000..77e06a29163aa335dc61e8c50932d5f83322ae1d --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py @@ -0,0 +1,27 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for histogram kernel. +# NOTE: Does not support gradients. +class HistogramOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, max_val: float): + return ops.histogram(x, max_val) + + +histogram = HistogramOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..9de8e65281c829ba63b5c60853dfa5bb0a333988 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -0,0 +1,78 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_HISTOGRAM_TESTS = ( + (16384, torch.int32, 2), + (16384, torch.int32, 4), + (16384, torch.int32, 8), + (16384, torch.int32, 16), + (16384, torch.int32, 32), + (16384, torch.int32, 64), + (16384, torch.int32, 128), + (16384, torch.int32, 256), +) + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. 
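+    # The untimed call absorbs one-time costs (CUDA context setup, kernel
+    # compilation/autotuning) so they do not skew the measured iterations.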
+ fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class HistogramBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testHistogram(self, n, dtype, max_val): + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testTorchHistogram(self, n, dtype, max_val): + x = torch.randint(0, 128, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..bfa7b7c2f697f9454c5381e31ee78040be5b229e --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -0,0 +1,403 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import stk +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + + +# Calling tensor.t() calls tensor.transpose(0, 1) which calls +# torch.as_strided(...). Circumvent this chain to avoid an overhead +# this adds. +def transpose_view(x): + return torch.as_strided( + x, + (x.shape[1], x.shape[0]), + (x.stride()[1], x.stride()[0]), + ) + + +_MATMUL_TESTS = ( + (64 * 1024, 512, 2048, 64), + (32 * 1024, 768, 3072, 64), + (8 * 1024, 1024, 4096, 64), + (4 * 2048, 4096, 4 * 4096, 4), +) + + +def log_benchmark(name, arguments, time, std, flops): + benchmark_util.log_benchmark(name, arguments, time, std) + print('flops = {:.2f}B'.format(flops / 1e9)) + print('throughput = {:.2f}T'.format(flops / 1e9 / time)) + print('=' * 60) + + +class MatmulBenchmark(parameterized.TestCase): + + def build_sparse_matrix(self, x, padded_bins, fhs, ne): + blocking = 128 + padded_tokens, _ = x.size() + assert padded_tokens % blocking == 0 + assert fhs % blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // blocking + blocks_per_row = fhs // blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. 
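+        # Illustrative sizes: blocking=128, padded_tokens=256 and fhs=512 give
+        # block_rows=2, blocks_per_row=4 and offsets=[0, 4, 8].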
+ column_indices = ops.topology( + padded_bins, + blocking, + block_rows, + blocks_per_row, + ) + data = torch.empty( + column_indices.numel(), + blocking, + blocking, + dtype=torch.float16, + device=x.device, + ) + shape = (padded_tokens, fhs * ne) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix(shape, data, row_indices, column_indices, offsets) + + def build_input_matrix(self, sl, hs, ne): + x = torch.randn((sl, hs)).cuda().half() + + # Assign tokens to experts uniformly. + top_expert = torch.arange(0, sl).cuda().int() % ne + + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) + return out, padded_bins + + def build_weight_matrix(self, ne, hs, fhs): + return torch.randn((hs, ne * fhs)).cuda().half() + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(x, w, topo) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd::SDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(topo, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradX::DSD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + topo = topo.t() + + def benchmark(): + return stk.ops.dsd(topo, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DSD::NN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): 
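+        # This case times dX = dY @ W^T for the second linear layer, written as an
+        # SDD product so the result reuses the block-sparse topology of the input.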
+ x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(out, w, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::SDD::NT', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + x = x.t() + + def benchmark(): + return stk.ops.dsd(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + + w = w.transpose(1, 2).contiguous() + w = w.transpose(1, 2) + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd:DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + w = w.transpose(1, 2).contiguous() + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradX:DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + out = out.transpose(1, 2) + + def benchmark(): + return torch.bmm(out, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradW:DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) 
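+
+    # The three-letter suffixes appear to follow stk's output/lhs/rhs convention
+    # (e.g. SDD = sparse output from dense x dense, DSD = dense output from
+    # sparse x dense), while the DDD cases below time a plain torch.bmm baseline
+    # on an expert-batched dense layout.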
+ + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + w = torch.transpose(w, 1, 2) + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + x = torch.transpose(x, 1, 2) + + def benchmark(): + return torch.bmm(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..f272a7768dc6468e81bd2cd25294ca16a6826c08 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py @@ -0,0 +1,55 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_gather kernel. +class PaddedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins, padded_bins) + ctx.top_k = top_k + return kernels.padded_gather( + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins, padded_bins = ctx.saved_tensors + out = kernels.padded_scatter( + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) + return out, None, None, None, None, None + + +padded_gather = PaddedGatherOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff81dd9456a04258346266e9e85871bda56c65b --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py @@ -0,0 +1,98 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_scatter kernel. 
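+# In short: padded_scatter routes rows from the padded, expert-grouped layout back
+# to token order, optionally scaling each of the top_k copies by its routing weight
+# and summing them into a single row per token.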
+class PaddedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward( + indices, + bin_ids, + weights, + bins, + padded_bins, + *maybe_x, + ) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.padded_gather( + grad, + indices, + bin_ids, + weights, + bins, + padded_bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.padded_scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + padded_bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None, None + + +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): + return PaddedScatterOp.apply( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..81dde4e4b4c0f550c6027390e8de285fa4d842f7 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -0,0 +1,66 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PADDED_SCATTER_BENCHMARK = ( + # dMoE-Medium, 8-way EMP. + (1024 * 16, 1024, 8, 4), + # dMoE-Medium, post-all-to-all. + (1024 * 16 * 4, 1024, 8, 1), +) + + +class PaddedScatterTest(parameterized.TestCase): + + @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) + def testPaddedScatter(self, sl, hs, ne, top_k): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + # Sample weights for the scatter reduce. + weights = torch.rand((sl * top_k,)).cuda().half() + + # Gather the data to prepare for backwards. 
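+        # Illustrative shapes: with ne=2 and tokens_per_expert=[3, 5], rounding up to
+        # 128 gives padded_tokens_per_expert=[128, 128] and padded_bins=[128, 256],
+        # so the gathered tensor below would have 256 rows.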
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + def benchmark(): + return ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + time, std = benchmark_util.benchmark_function(benchmark) + benchmark_util.log_benchmark( + 'Padded Scatter', + { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + 'top_k': top_k, + }, + time, + std, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..837f07e2858d5c65c14e3f262617e7d2985ce901 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -0,0 +1,149 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PERMUTE_TESTS = ( + (16384, 768, 2), + (16384, 768, 4), + (16384, 768, 8), + (16384, 768, 16), + (16384, 768, 32), + (16384, 768, 64), + (16384, 768, 128), + (16384 * 8, 768, 2), + (16384 * 8, 768, 4), + (16384 * 8, 768, 8), + (16384 * 8, 768, 16), + (16384 * 8, 768, 32), + (16384 * 8, 768, 64), + (16384 * 8, 768, 128), +) + + +class PermuteBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedGather(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.binned_gather(x, indices, bins, ec) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedScatter(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.binned_gather(x, indices, bins, ec) + + def benchmark(): + return ops.binned_scatter(x, indices, bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedGather(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. 
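+        # ops.sort returns (sorted expert ids, argsort permutation); e.g. (illustrative)
+        # top_expert=[2, 0, 1, 0] yields bin_ids=[0, 0, 1, 2] and indices=[1, 3, 2, 0].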
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedScatter(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testCopy(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + # ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + y = x.clone() + + def benchmark(): + return y.copy_(x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def repeat(x: torch.Tensor, tiling: torch.Size): + if all((t == 1 for t in tiling)): + return x + return x.repeat(*tiling) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..e570e22093339d1b71dc830edf656897c67f40e7 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py @@ -0,0 +1,36 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. 
+try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for replicate kernel. +class ReplicateOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): + ctx.save_for_backward(bins) + out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) + ops.replicate_forward(x, bins, out) + return out + + @staticmethod + def backward(ctx: Any, grad: torch.Tensor): + bins, = ctx.saved_tensors + out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) + ops.replicate_backward(grad, bins, out) + return out, None, None + + +replicate = ReplicateOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py new file mode 100644 index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py @@ -0,0 +1,14 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def round_up(x: torch.Tensor, value: int): + assert isinstance(value, int) + assert x.dtype == torch.int32 + + # TODO(tgale): If this becomes and issue + # do this in a custom kernel. We only expect + # to use this on arrays of less than 1k elements. + return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..a5aaafc402a651a47595c25dbf71e382903e5022 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py @@ -0,0 +1,72 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for scatter kernel. 
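+# In short: scatter is the inverse of gather. It returns rows from expert-contiguous
+# order to token order, optionally weighting and summing the top_k copies; it is
+# implemented as padded_scatter with bins reused as padded_bins (i.e. no padding).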
+class ScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins = saved_tensors[:4] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.gather( + grad, + indices, + bin_ids, + weights, + bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None + + +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +) -> Optional[torch.Tensor]: + return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py new file mode 100644 index 0000000000000000000000000000000000000000..a64a4553323d8de71a1d5f1821bac50249c68197 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional, Tuple + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + +_BITS_FOR_DTYPE = { + torch.int16: 16, + torch.int32: 32, + torch.int64: 64, +} + + +# Autograd wrapper for sort kernel. +# NOTE: Does not support gradients. 
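+# When end_bit is None it defaults to the full bit width of the dtype. Callers sorting
+# expert ids in [0, num_experts) can pass ceil(log2(num_experts)) instead (e.g. 6 bits
+# for 64 experts) so the radix sort inspects fewer bits.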
+class SortOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + if end_bit is None: + end_bit = _BITS_FOR_DTYPE[x.dtype] + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + ops.sort(x, end_bit, x_out, iota_out) + return (x_out, iota_out) + + +sort = SortOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -0,0 +1,85 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_SORT_TESTS = ( + (16384, torch.int32, None), + (16384, torch.int32, 2), + (16384, torch.int32, 128), +) + +_BASELINE_SORT_TESTS = ((16384,),) + + +def numpy_dtype(dtype): + types = { + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + } + return types[dtype] + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. + fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class SortBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_SORT_TESTS) + def testSort(self, n, dtype, max_val): + if max_val is None: + max_val = np.iinfo(numpy_dtype(dtype)).max + end_bit = int(np.ceil(np.log2(max_val))) + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_BASELINE_SORT_TESTS) + def testTorchSort(self, n): + x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) + arguments = { + 'n': n, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py new file mode 100644 index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py @@ -0,0 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import torch + + +def sum(x: torch.Tensor, dim: int = 0): + if x.shape[dim] == 1: + return x.squeeze(dim=dim) + return x.sum(dim=dim) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py new file mode 100644 index 
0000000000000000000000000000000000000000..b4a693a82a36c052e51afb4a02077e7849a4abf1 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py @@ -0,0 +1,45 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for topology kernel. +# NOTE: Does not support gradients. +class TopologyOp(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, + ): + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + ops.indices( + padded_bins, + block_size, + output_block_rows, + output_block_columns, + out, + ) + return out + + +topology = TopologyOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d8a2495c4baab89e5f90d84fa92f406a29ea10 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py @@ -0,0 +1,195 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from ._ops import ops + +from megablocks.layers.arguments import Arguments +from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE +from megablocks.layers.glu import SparseGLU +from megablocks.layers.mlp import MLP, SparseMLP +from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +# This section contains the direct kernel exports (not inlcuded in the original code) +def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute exclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.exclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute inclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.inclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: + """ + Compute histogram of input tensor values. + + Args: + x: Input tensor + num_bins: Number of histogram bins + + Returns: + Histogram tensor with counts for each bin + """ + return ops.histogram(x, num_bins) + + +def indices( + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, +) -> torch.Tensor: + """ + Construct indices from padded bins for sparse operations. 
+ + Args: + padded_bins: Tensor containing bin boundaries + block_size: Size of each block + output_block_rows: Number of rows in output blocks + output_block_columns: Number of columns in output blocks + + Returns: + Tensor containing constructed indices + """ + return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) + + +def replicate_forward( + x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Forward pass of replicate operation - replicate values according to bin sizes. + + Args: + x: Input tensor with values to replicate + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_forward(x, bins, out) + + +def replicate_backward( + grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Backward pass of replicate operation - reduce gradients back to bins. + + Args: + grad: Gradient tensor to reduce + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_backward(grad, bins, out) + + +def sort( + x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor +) -> torch.Tensor: + """ + Radix sort with index tracking. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + x_out: Output tensor for sorted values + iota_out: Output tensor for sorted indices + + Returns: + The sorted values tensor + """ + return ops.sort(x, end_bit, x_out, iota_out) + + +# Convenience functions for common use cases +def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: + """ + Compute cumulative sum with automatic output allocation. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum (default: last dimension) + exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum + + Returns: + New tensor containing the cumulative sum + """ + out = torch.empty_like(x) + if exclusive: + return exclusive_cumsum(x, dim, out) + else: + return inclusive_cumsum(x, dim, out) + + +def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sort tensor and return both sorted values and indices. 
+ + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + + Returns: + Tuple of (sorted_values, sorted_indices) + """ + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + sort(x, end_bit, x_out, iota_out) + return x_out, iota_out + + +# Export public API +__all__ = [ + # Direct kernel exports + "exclusive_cumsum", + "inclusive_cumsum", + "histogram", + "indices", + "replicate_forward", + "replicate_backward", + "sort", + "cumsum", + "argsort", + # Original exports + "Arguments", + "ParallelDroplessMLP", + "dMoE", + "SparseGLU", + "MLP", + "SparseMLP", + "MoE", + "ParallelMLP", + "get_load_balancing_loss", +] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..7d62280654256c29f19dc4a96c5b895591d064be --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d15bbd83210a53d26b152adb3388900213fd5ff70d5b77a61acac53b6e89fbe +size 11835920 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd42432c910375fb1713b3bf29145e48b556498 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _megablocks_359242d +ops = torch.ops._megablocks_359242d + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_megablocks_359242d::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_version.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_version.py @@ -0,0 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +"""The MegaBlocks Version.""" + +__version__ = '0.11.0.dev0' diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..b584ceede926ca30abef2dec581cb3ff329e8e16 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py @@ -0,0 +1,543 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') + + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') + + +def assert_equal(a, b): + if a != b: + raise ValueError(f'Expected 
dimensions to be equal but got {a} and {b}.',) + + +# a: (tokens, hidden_size), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy( + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Our index into array 'a'. + index_a = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'b'. + index_b = offset_in_bin + if bin_idx > 0: + index_b += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: Because of the padding, the output size is dynamic. + # We load the final padded bin bound to get the output rows. + output_rows = padded_bins[-1].cpu().item() + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def gather(x, indices, bin_ids, weights, bins, top_k): + # Validate the input shapes. 
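+    # Expected shapes: x is (tokens, hidden); indices and bin_ids are (tokens * top_k,);
+    # bins is (num_experts,). The result is (tokens * top_k, hidden) with rows grouped
+    # by expert and no padding (bins is passed for both bins and padded_bins).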
+ assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: There is no padding so the output rows equals the + # input rows multiplied by top_k. + output_rows = x.shape[0] * top_k + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + out, + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + + +def scatter(x, indices, bin_ids, weights, bins, top_k): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +# x: (tokens, top_k, hidden_size), real +# grad: (tokens, hidden_size), real. +# wgrad: (tokens, top_k), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy_wgrad( + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Our index into 'tokens * top_k'. + index_out = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'x'. + index_x = offset_in_bin + if bin_idx > 0: + index_x += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. 
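+    # Each program produces one scalar: wgrad[index_out] is the inner product of the
+    # expert-output row x[index_x, :] with the incoming gradient row for this token,
+    # accumulated in fp32 over BLOCK_X-wide chunks.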
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + _padded_copy_wgrad[(indices.shape[0],)]( + x, + grad, + out, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + TOP_K=top_k, + ) + return out + + +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy( + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_b = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_a = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + # + # NOTE: We need to zero the output in both directions. 
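+    # Programs with entry_idx >= num_tokens exit before this point, so capacity slots
+    # that receive no token are never written; callers therefore allocate the
+    # destination with torch.zeros to keep those slots zeroed.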
+ iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + num_experts = bins.shape[0] + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + + _binned_copy[(num_experts, expert_capacity)]( + x, + out, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def binned_scatter(x, indices, weights, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( + out, + x, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=hidden_size, + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy_wgrad( + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_x = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_out = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. 
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def binned_scatter_wgrad(x, grad, indices, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) + _binned_copy_wgrad[(num_experts, expert_capacity)]( + x, + grad, + out, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS=hidden_size, + TOP_K=top_k, + ) + return out diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/bak.__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/bak.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5217959caf74527e3bf7f80db6f93be21c016963 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/bak.__init__.py @@ -0,0 +1,23 @@ +from megablocks_moe.megablocks import ( + MoE, + dMoE, + get_load_balancing_loss, + ParallelMLP, + ParallelDroplessMLP, + SparseMLP, + MLP, + SparseGLU, + Arguments, +) + +__all__ = [ + "MoE", + "dMoE", + "get_load_balancing_loss", + "ParallelMLP", + "ParallelDroplessMLP", + "SparseMLP", + "MLP", + "SparseGLU", + "Arguments", +] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/benchmark_util.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/benchmark_util.py new file mode 100644 index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/benchmark_util.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + + +def log_benchmark(name, arguments, time, std): + print('=' * 60) + print(f'{name} Benchmark') + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) + print('=' * 60) + + +def benchmark_function(fn, iterations=100, warmup=10): + # Warmup iterations. 
+ for _ in range(warmup): + fn() + + times = [] + for i in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + fn() + end.record() + + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + return np.mean(times), np.std(times) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6d3f977f360fd0ad5800c3b5da9ce57be794b9b8 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import warnings + +_grouped_gemm_is_available: bool = False +try: + import grouped_gemm + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn('Grouped GEMM not available.') + + +def grouped_gemm_is_available(): + return _grouped_gemm_is_available + + +def assert_grouped_gemm_is_available(): + msg = ( + 'Grouped GEMM not available. Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg + + +backend = grouped_gemm.backend if grouped_gemm_is_available() else None +ops = grouped_gemm.ops if grouped_gemm_is_available() else None diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..849b023b1a0f765fb1e26addf4670dc6db785a52 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# from megablocks.layers.dmoe import dMoE +from megablocks.layers.moe import MoE + +__all__ = [ + 'MoE', + # 'dMoE', +] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/activation_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..a31770ba179fed06abf2da10102ccaeed1d3ee4e --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/activation_fn.py @@ -0,0 +1,33 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Union + +import torch +from stk import Matrix + + +def act_fn( + x: Matrix, + function: Callable, + return_grad_fn: bool = False, + **kwargs, +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) + with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): + if return_grad_fn: + x.data.requires_grad = True + out = function(x.data, **kwargs) + y = Matrix( + x.size(), + out, + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + if return_grad_fn: + return y, out.backward + return y diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/all_to_all.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/all_to_all.py @@ -0,0 +1,54 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + + +class 
AllToAllOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) + + ctx.input_shape = x.shape + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + ctx.group = group + handle = dist.all_to_all_single( + out, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + return out, handle + + @staticmethod + def backward(ctx, grad, _): + if ctx.needs_input_grad[0]: + out = torch.empty( + ctx.input_shape, + device=grad.device, + dtype=grad.dtype, + ) + dist.all_to_all_single( + out, + grad, + output_split_sizes=ctx.input_split_sizes, + input_split_sizes=ctx.output_split_sizes, + group=ctx.group, + ) + return out, None, None, None, None + return None, None, None, None, None + + +def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): + return AllToAllOp.apply( + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..3962c771c90012535aab058443f89c541a2e9236 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py @@ -0,0 +1,100 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from functools import partial +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +import megablocks.grouped_gemm_util as grouped_gemm + +# Type annotation for in-place Tensor initialization function. +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] + +_ALLOWED_BITWIDTHS = (-1, 4, 8) + +DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') + + +@dataclasses.dataclass +class Arguments: + # Model arguments. + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + num_layers: int = 1 + bias: bool = True + return_bias: bool = True + activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN + + # MoE arguments. + moe_num_experts: int = 1 + moe_top_k: int = 1 + moe_capacity_factor: int = 1 + moe_normalize_expert_weights: Optional[Union[int, float]] = None + moe_loss_weight: float = 0.1 + moe_jitter_eps: Optional[float] = None + moe_lbl_in_fp32: bool = False + + # Parallelism arguments. + moe_expert_model_parallelism: bool = False + expert_parallel_group: Optional[dist.ProcessGroup] = None + pipeline_model_parallel_size: int = 1 + num_layers_per_virtual_pipeline_stage: Optional[int] = None + + # Compute arguments. + memory_optimized_mlp: bool = False + mlp_type: str = 'mlp' + mlp_impl: str = 'sparse' + + # Initialization arguments. + fp16: bool = True + bf16: bool = False + device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) + init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) + output_layer_init_method: InitFn = init_method + + # Benchmarking arguments. 
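+ # Presumably forces a uniform token-to-expert assignment in the router so that
+ # benchmarks are not skewed by routing imbalance (behavior not documented here).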
+ uniform_expert_assignment: bool = False + + # shared expert arguments + shared_expert: bool = False # enable using shared expert + fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) + fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers + remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored + shared_expert_hidden_size: Optional[ + int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size + shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) + + # Router Z-loss arguments + moe_zloss_weight: float = 0 # 1e-3 is a reasonable value + moe_zloss_in_fp32: bool = False + + def __post_init__(self): + # Sparse MLP is not supported with triton >=3.2.0 + # TODO: Remove this once sparse is supported with triton >=3.2.0 + if self.__getattribute__('mlp_impl') == 'sparse': + try: + import triton + if triton.__version__ >= '3.2.0': + raise ValueError( + 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.', + ) + except ImportError: + raise ImportError('Triton is required for sparse MLP implementation') + + if self.__getattribute__('mlp_impl') == 'grouped': + grouped_gemm.assert_grouped_gemm_is_available() + + if self.shared_expert_hidden_size is None: + self.shared_expert_hidden_size = self.ffn_hidden_size + + +def from_megatron(megatron_args: Any): + args = Arguments() + for field in dataclasses.fields(args): + if hasattr(megatron_args, field.name): + setattr(args, field.name, getattr(megatron_args, field.name)) + return args diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/common.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/common.py new file mode 100644 index 0000000000000000000000000000000000000000..ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/common.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from megablocks.layers.arguments import Arguments + + +def dtype(args: Arguments): + if args.fp16: + return torch.float16 + elif args.bf16: + return torch.bfloat16 + return None + + +def cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == 'cuda': + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == 'cpu': + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d765bd04387a29bddd2789fae04821635b555e82 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -0,0 +1,42 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +MlpType = Union[mlp.SparseMLP, glu.SparseGLU] + +_REGISTRY = { + 'mlp': { + 'grouped': mlp.GroupedMLP, + 'sparse': mlp.SparseMLP, + }, + 'glu': { + 'grouped': glu.GroupedGLU, + 'sparse': glu.SparseGLU, + }, +} + + +def get(args: Arguments) -> MlpType: + 
"""Returns an MLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + MLP instance. This only contains MLPs for use in dMoEs + (ie. only for the dropless versions of MoEs). + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated MLP constructed using the input args. + """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + if args.mlp_impl not in _REGISTRY[args.mlp_type]: + raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) + + return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..205727ff4d63f9e8dc9648acaac99a97f3394d6f --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py @@ -0,0 +1,327 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import stk.ops +import torch +from stk import Matrix + +import megablocks.ops as ops +# from megablocks.ops import ops +from megablocks.layers import common, dmlp_registry, moe, mpu +from megablocks.layers.arguments import Arguments + + +def promote_scalar(x): + return x.view(1) if not len(x.size()) else x + + +class ParallelDroplessMLP(moe.ParallelMLP): + + def __init__(self, args: Arguments): + super(ParallelDroplessMLP, self).__init__(args) + self.hidden_size = args.hidden_size + self.ffn_hidden_size = mpu.features_per_rank(args) + self.blocking = 128 + self.mlp = dmlp_registry.get(args) + + # Calculate the number of bits needed to represent the column indices + # in the intermediate sparse matrix. + max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) + self.transpose_sort_end_bit = max( + int(np.ceil(np.log2(max_column_index))), + 1, + ) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + # + # NOTE: Our sort operation uses the same width indices as the input values. + # To avoid overflow when we have large activation matrices we cast to + # 32-bit before sorting. + _, gather_indices = ops.sort( + column_indices.int(), + self.transpose_sort_end_bit, + ) + + # There are a constant number of blocks in every row of the sparse matrix. + # A blocks offset is: + # + # row_index * blocks_per_row + column_index % blocks_per_row + # + # Once we have the block offsets ordered for transposition we can divide + # by blocks_per_row to get the transposed column indices. + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + if nnz_per_column.dim() == 0: + # This addresses an edge case when ffn_hidden_size is equal to self.blocking. 
+ nnz_per_column = nnz_per_column.unsqueeze(0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + if self.ffn_hidden_size % self.blocking != 0: + raise ValueError( + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + + f'the block size {self.blocking}. Please update your configuration.', + ) + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // self.blocking + blocks_per_row = self.ffn_hidden_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology( + padded_bins, + self.blocking, + block_rows, + blocks_per_row, + ) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=common.dtype(self.args), + device='meta', + ) + shape = ( + padded_tokens, + self.ffn_hidden_size * mpu.experts_per_rank(self.args), + ) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, + row_indices, + column_indices, + offsets, + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_experts = top_experts.int() + bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Caculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = promote_scalar(bins) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def sparse_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather( + x, + indices, + bin_ids, + bins, + padded_bins, + self.top_k, + ) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. 
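+ # padded_scatter reverses the padded_gather permutation and combines the
+ # top-k expert outputs for each token, weighted by `expert_weights`.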
+ x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.top_k, + ) + return x, tokens_per_expert + + # For use in the base-class parallel_forward_once. + def sparse_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + return ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + top_k, + ) + + def grouped_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + out = self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + -1, # unused + self.args.moe_top_k, + ) + return out, tokens_per_expert + + def grouped_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Perform the expert computation. + x = self.mlp(x, tokens_per_expert) + + # Un-route the data for the MoE output. 
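+ # The grouped path routes tokens without padding to the block size, so a
+ # plain scatter (rather than padded_scatter) is enough to un-permute and
+ # combine the weighted top-k outputs.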
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + def forward_once(self, x, expert_weights, top_experts): + if self.args.mlp_impl == 'sparse': + return self.sparse_forward_once(x, expert_weights, top_experts) + else: + return self.grouped_forward_once(x, expert_weights, top_experts) + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ): + if self.args.mlp_impl == 'sparse': + return self.sparse_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + else: + return self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + + +class dMoE(moe.MoE): + + def _init_experts_mlp(self, args: Arguments): + return ParallelDroplessMLP(args) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/gelu.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..40b601d4a8eb59e80e1090f31f4172b8f7fb7549 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/gelu.py @@ -0,0 +1,43 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk +import torch +import torch.nn.functional as F + + +@torch.jit.script +def _gelu_backward_inplace(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) + return g.mul_(ff) + + +def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): + # NOTE: The two sparse matrices must have the same topology. + if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): + return stk.Matrix( + x.size(), + _gelu_backward_inplace(grad.data, x.data), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + return _gelu_backward_inplace(grad, x) + + +def gelu(x: stk.Matrix): + assert isinstance(x, stk.Matrix) + return stk.Matrix( + x.size(), + F.gelu(x.data, approximate='tanh'), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe0c915c307e7f7cade3ea3ff679399635fcd81 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py @@ -0,0 +1,223 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk.ops +import torch + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import Arguments +from megablocks.layers.mlp import ( + SharedMLP, + SparseMLP, + create_dmoe_expert_weights, + resolve_dtensor, +) + + +class SparseGLU(SparseMLP): + + def __init__(self, args: Arguments): + super().__init__(args) + self.v1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + with torch.no_grad(): + self.v1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + + mpu.set_expert_model_parallel_attributes( + 
self.v1, + self._should_set_parallelism_attribute, + ) + + def forward(self, x, topo): + if self.args.memory_optimized_mlp: + raise NotImplementedError( + 'Memory optimized implementation not yet supported with GLU with sparse kernels.', + ) + + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Compute the GLU. + x1 = stk.ops.sdd(x, w1.t(), topo) + x2 = stk.ops.sdd(x, v1.t(), topo) + + activation_fn_out = act_fn(x1, self.args.activation_fn) + x1 = stk.ops.mul(activation_fn_out, x2) + + return stk.ops.dsd(x1, w2) + + +class MemoryOptimizedGroupedGLU(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + v1 = v1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) + + # GeLU. + activation_fn_out = activation_fn(sdd_out) * v1_out + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, v1, w2 = saved_tensors[:3] + batch_sizes = saved_tensors[3] + x = saved_tensors[4] + sdd_out, v1_out = saved_tensors[5:7] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + v1_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * v1_out + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + dv1_out = v1_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dv1. 
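+ # dv1 mirrors dw1 for the gating branch: it is computed from the
+ # rematerialized v1-branch gradient and the saved layer input x.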
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + dx = ddsd_out + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) + dx += gg.backend.gmm(dv1_out, v1, batch_sizes) + return dx, dw1, dv1, dw2, None, None + + +memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply + + +class GroupedGLU(SparseGLU): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, v1, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.v1), + self.scale_grad(self.w2), + ) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Re-shape the weights for the grouped GEMMs. + ne = mpu.experts_per_rank(self.args) + w1 = w1.view(ne, -1, self.args.hidden_size) + v1 = v1.view(ne, -1, self.args.hidden_size) + w2 = w2.view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_glu( + x, + w1, + v1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) + x1 = self.args.activation_fn(x1) * x2 + return gg.ops.gmm(x1, w2, batch_sizes) + + +class SharedGLU(SharedMLP): + """GPU for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class + """ + + def __init__(self, args: Arguments): + super().__init__(args) + self.gate_proj = args.fc_cls( + args.hidden_size, + self.args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4acbd94f212ce906ace9de78dd3ffc6afa03f97e --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py @@ -0,0 +1,102 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import gc + +import torch +import torch.distributed as dist + +from megablocks.layers import arguments, dmoe + +_TESTS = ((8, 2048, 4096, 4096, 32, 4),) + + +def get_tensors(): + ptrs = set() + out = [] + for obj in gc.get_objects(): + if torch.is_tensor(obj): + if not obj.is_contiguous() or obj.data_ptr() in ptrs: + continue + out.append(obj) + ptrs.add(obj.data_ptr()) + return out + + +def test_memory( + group, + batch_size, + sequence_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, +): + args = arguments.Arguments( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=num_experts, + moe_top_k=top_k, + moe_expert_model_parallelism=True, + expert_parallel_group=group, + fp16=False, + bf16=True, + device=torch.cuda.current_device(), + ) + layer = dmoe.dMoE(args).cuda() + + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16).requires_grad_(True) + torch.cuda.empty_cache() + + # Run forward + backward. + # with torch.autograd.detect_anomaly(): + out, _ = layer(x) + out.mean().backward() + + # Report peak memory. 
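+ # torch.cuda.max_memory_allocated() reports the peak tensor allocation since
+ # program start (or the last reset), which here covers the forward and
+ # backward pass above.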
+ mem = torch.cuda.max_memory_allocated() + print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) + print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) + + # Calculate weight and gradient memory usage. + weight_memory = 2 * ( + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() + ) + + def grad_numel(x): + if x.grad is not None: + return x.grad.numel() + return 0 + + grad_memory = 2 * ( + grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + ) + weight_memory += grad_memory + + print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) + print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) + + # Manually calculate GPU memory usage from the garbage + # collector. + gc.collect() + total = 0 + tensors = get_tensors() + tensors = sorted(tensors, key=lambda x: -x.numel()) + for i, t in enumerate(tensors): + total += t.numel() + print(f'{i}: {t.shape}, {t.numel() * 2}') + del tensors + + print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _TESTS: + test_memory(group, *args) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6f4d82441e3a9c185db5bfdf686d53790dde26 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py @@ -0,0 +1,574 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import stk +import stk.backend.triton_kernels +import stk.ops +import torch +from packaging import version + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, gelu, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + + +class ScaleGradient(torch.autograd.Function): + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx: Any, x: torch.Tensor, scale: float): + ctx.scale = scale + return x + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: torch.Tensor, grad: torch.Tensor): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +def resolve_dtensor(weight: torch.Tensor): + if version.parse(torch.__version__) >= version.parse('2.0.0'): + from torch.distributed._tensor import DTensor + if isinstance(weight, DTensor): + return weight.to_local() + return weight + + +def create_moe_expert_weights( + args: Arguments, + num_experts: int, + ffn_hidden_size: int, + hidden_size: int, + init_method: InitFn, +): + # Create the entire weight matrix such that the sampled weights will + # not vary between data parallelism and expert model parallelism for + # the same random seed. + master_weights = torch.empty( + num_experts, + ffn_hidden_size, + hidden_size, + device=args.device, + dtype=common.dtype(args), + ) + init_method(master_weights) + + if not args.moe_expert_model_parallelism: + return master_weights + + # Calculate the amount of sharding in each dimension. 
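+ # The expert sharding degree partitions the expert dimension across ranks and
+ # the hidden sharding degree partitions each expert's ffn rows; together they
+ # select the slice of the master weights that this rank keeps below.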
+ expert_sharding_degree = mpu.expert_sharding_degree(args) + hidden_sharding_degree = mpu.hidden_sharding_degree(args) + + # Calculate the experts per rank. + # + # NOTE: We assign ranks to be expert parallel before going + # tensor parallel. + rank = mpu.get_expert_parallel_rank(args) + expert_rank = rank % expert_sharding_degree + num_experts_per_rank = num_experts // expert_sharding_degree + start_expert = expert_rank * num_experts_per_rank + end_expert = (expert_rank + 1) * num_experts_per_rank + + # Calculate the rows per rank. + row_rank = rank // expert_sharding_degree + num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree + start_row = row_rank * num_rows_per_rank + end_row = (row_rank + 1) * num_rows_per_rank + + # Slice the weight matrix to get the chunk for this rank. + with torch.no_grad(): + weights = master_weights[start_expert:end_expert, start_row:end_row] + return weights + + +class MLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) + experts_per_rank = mpu.experts_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + args.hidden_size, + mpu.features_per_rank(args), + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + mpu.features_per_rank(args), + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + mpu.set_expert_model_parallel_attributes( + self.w1, + args.moe_expert_model_parallelism, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + args.moe_expert_model_parallelism, + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. 
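+ # Note that w1 is stored as [experts_per_rank, hidden_size, features_per_rank]
+ # (hence the transpose below) so the forward pass can apply it with torch.bmm
+ # directly.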
+ with torch.no_grad(): + w1 = create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) + self.w1.copy_(w1.transpose(1, 2).contiguous()) + self.w2.copy_( + create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + x = torch.bmm(x, w1) + x = self.args.activation_fn(x) + return torch.bmm(x, w2) + + +def create_dmoe_expert_weights( + args: Arguments, + num_experts: int, + rows: int, + columns: int, + init_method: InitFn, +): + weights = create_moe_expert_weights( + args, + num_experts, + rows, + columns, + init_method, + ) + return weights.view([-1, columns]) + + +class MemoryOptimizedMLP(torch.autograd.Function): + """Sparse MLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, topo, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + topo_tensors = ( + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) + + # Layer 0: x @ w1.t(). + sdd_out = stk.ops.sdd(x, w1.t(), topo) + + # GeLU. + activation_fn_out = act_fn(sdd_out, activation_fn) + + # Layer 1: x @ w2. + dsd_out = stk.ops.dsd(activation_fn_out, w2) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.shape = topo.shape + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.data.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + topo_tensors = saved_tensors[2:8] + x = saved_tensors[8] + sdd_out_data = saved_tensors[9] + + # rematerialize activation function output + activation_fn = ctx.activation_fn + sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) + activation_fn_out, activation_grad_fn = act_fn( + sdd_out, + activation_fn, + return_grad_fn=True, + ) + + # Compute dw2 with recomputed activation_fn output. + dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. 
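+ # The sdd call below writes the gradient w.r.t. the activation output straight
+ # into activation_fn_out's storage instead of allocating a new sparse matrix.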
+ dactivation_fn_out = activation_fn_out + stk.backend.triton_kernels.sdd( + ddsd_out, + w2.t(), + dactivation_fn_out.shape, + dactivation_fn_out.data, + dactivation_fn_out.offsets, + dactivation_fn_out.row_indices, + dactivation_fn_out.column_indices, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out.data) + dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) + + # Compute dw1. + dw1 = stk.ops.dsd(dsdd_out.t(), x) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + stk.backend.triton_kernels.dsd( + dsdd_out.shape, + dsdd_out.data, + dsdd_out.offsets, + dsdd_out.row_indices, + dsdd_out.column_indices, + dsdd_out.offsets_t, + dsdd_out.column_indices_t, + dsdd_out.block_offsets_t, + False, + w1, + ddsd_out, + ) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_mlp = MemoryOptimizedMLP.apply + + +class SparseMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. + with torch.no_grad(): + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism + mpu.set_expert_model_parallel_attributes( + self.w1, + self._should_set_parallelism_attribute, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + self._should_set_parallelism_attribute, + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x, topo): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + if self.args.memory_optimized_mlp: + return memory_optimized_mlp( + x, + w1, + w2, + topo, + self.args.activation_fn, + ) + + # Compute the MLP. 
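+ # sdd multiplies two dense operands into a sparse output with topology `topo`;
+ # dsd multiplies the sparse intermediate by a dense weight back into a dense
+ # output. The activation runs on the sparse intermediate via act_fn.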
+ x = stk.ops.sdd(x, w1.t(), topo) + activation_fn_out = act_fn(x, self.args.activation_fn) + return stk.ops.dsd(activation_fn_out, w2) + + +class MemoryOptimizedGroupedMLP(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + + # activation_fn + activation_fn_out = activation_fn(sdd_out) + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: Any, ddsd_out: torch.Tensor): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + batch_sizes = saved_tensors[2] + x = saved_tensors[3] + sdd_out = saved_tensors[4] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply + + +class GroupedMLP(SparseMLP): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) + + # Re-shape the weights for the grouped GEMMs. 
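+ # The grouped GEMM expects one [rows, hidden_size] matrix per expert, so the
+ # flat [num_rows_per_rank, hidden_size] parameters are viewed as
+ # [experts_per_rank, -1, hidden_size].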
+ ne = mpu.experts_per_rank(self.args) + w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) + w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_mlp( + x, + w1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x = self.args.activation_fn(x) + return gg.ops.gmm(x, w2, batch_sizes) + + +class SharedMLP(torch.nn.Module): + """MLP for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class + """ + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self.fc_kwargs: dict[str, Any] = { + 'bias': args.bias, + 'device': args.device, + } + self.fc_kwargs.update(args.fc_kwargs) + + self.up_proj = args.fc_cls( + args.hidden_size, + args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + self.act = args.activation_fn + self.down_proj = args.fc_cls( + args.shared_expert_hidden_size, + args.hidden_size, + **self.fc_kwargs, + ) + self.down_proj._is_residual = True # a flag for llm-foundry init + + def add_experts_sharedexpert( + self, + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + ) -> torch.Tensor: + # Helper function to add expert output to shared expert output + # with optional weighted sum. + if self.args.shared_expert_weighted_sum: + # enable using weighted sum for shared expert output + # wieghted by number of experts used + t_experts = self.args.moe_top_k + 1 + sh_mlp_out = shared_expert_out / t_experts + return sh_mlp_out.add( + expert_out, + alpha=(self.args.moe_top_k / t_experts), + ) + + return shared_expert_out + expert_out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py @@ -0,0 +1,475 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist + +import megablocks.ops as ops +from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +from megablocks.layers.all_to_all import all_to_all +from megablocks.layers.arguments import Arguments + +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss(args: Arguments): + if args.moe_loss_weight == 0: + return 0.0 + + # tokens_per_expert[i].shape = (num_experts) + # expert_scores[i].shape = (tokens, num_experts) + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} token_per_experts ' 
+ f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} expert_scores ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) + + tokens = expert_scores[0].shape[0] + assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = (args.moe_num_experts * args.moe_loss_weight) + scale_denominator = (args.num_layers * tokens * args.moe_top_k) + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +# NOTE: This class defines MoE expert computation, including expert model parallel +# communication. When using FSDP on top of MegaBlocks this is the module that should +# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model +# parallel all2all. +class ParallelMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super(ParallelMLP, self).__init__() + self.args = args + + # Calculate the number of experts in total and the number of experts + # owned by this rank. + # world_size = mpu.get_expert_parallel_world_size(args) + self.num_experts = args.moe_num_experts + self.top_k = self.args.moe_top_k + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + # Expert MLP. + self.mlp = mlp.MLP(args) + + self.bias: Optional[torch.Tensor] + if self.args.bias: + # Note that the output bias is not parallelized with expert + # model parallelism. + self.bias = torch.nn.Parameter( + torch.empty( + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + # Select the forward function for the operating mode. 
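+ # parallel_forward_once adds the expert-parallel all-to-all exchange around
+ # the expert computation; forward_once keeps everything on the local rank.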
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) + + def expert_capacity(self, tokens: int) -> int: + world_size = mpu.get_expert_parallel_world_size(self.args) + tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) + return int(self.args.moe_capacity_factor * tokens_per_expert) + + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): + """Calculate the load balancing loss contribution.""" + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == self.num_experts + assert len(tokens_per_expert.size()) == 1 + num_experts, = tokens_per_expert.size() + assert num_experts == self.num_experts + scale = self.num_experts / (tokens * self.top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # + # TODO(tgale): Is it worth doing this conversion to 32-bit + # prior? Could we place the `torch.max` operation to return + # 32-bit expert indices? + top_expert = top_expert.int() + output = ops.sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + # + # TODO(tgale): Does the sorted data produce a more favorable + # data distribution for histogram? Or is the op parallelism + # worth more? + tokens_per_expert = ops.histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None + bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + + return indices, bin_ids, bins, tokens_per_expert + + def permute_and_compute( + self, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, + ): + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output + + # Perform the expert computation. Note that we don't + # use biases for these linear operations. + x = self.mlp(x) + + # Un-route the data for the MoE output. + return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. 
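+ # For example, with sl * bs = 1024 tokens, top_k = 2, 8 experts, a world size
+ # of 1 and moe_capacity_factor = 1, expert_capacity(1024) = int(2 * 1024 / 8) = 256.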
+ sl, bs, _ = x.size() + expert_capacity = self.expert_capacity(sl * bs) + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = self.permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + self.top_k, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignement. After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + # + # Compute the mapping of local tokens to experts. + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so every device gets the counts. + repeated_tokens_per_expert = ops.repeat( + tokens_per_expert, + (mpu.hidden_sharding_degree(self.args),), + ) + + # Pass token count information to the device on which the + # target expert resides. + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.args.expert_parallel_group, + async_op=True, + ) + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + # + # This view updates the shape of the tensor from [sl, bs, hs] to + # [sl * bs, hs] prior to the permutation. + x = x.view(-1, x.shape[-1]) + output = ops.gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + experts_per_rank = mpu.experts_per_rank(self.args) + + # Reshape to [world_size, num_experts_per_rank]. + world_size = mpu.get_expert_parallel_world_size(self.args) + repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) + + # TODO(tgale): It might be faster to do this on the GPU and + # then communicate the results back to the host. + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. 
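+ # (Annotation, not part of the original code.) The split sizes handed to
+ # all_to_all must be plain Python integers on the host, which is why the
+ # counts are moved to the CPU above and converted with .tolist() below.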
+ send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so devices that share experts + # get all of the tokens assigned to them. + # + # TODO(tgale): Fuse this into the prior, local permutation. + x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, + recv_counts, + send_counts, + self.args.expert_parallel_group, + async_op=True, + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + replicate_bins = ops.inclusive_cumsum( + parallel_tokens_per_expert.flatten(), + 0, + ) + replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * mpu.hidden_sharding_degree(self.args), + dtype=torch.int32, + device=indices.device, + ), + mpu.experts_per_rank(self.args), + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, + tokens_received, + ).flatten() + + # TODO(tgale): The sort_end_bit here can be reduced. + parallel_bin_ids, parallel_indices = ops.sort( + parallel_top_expert, + self.sort_end_bit, + ) + + # Calculate the bins boundaries from the token counts. + parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, + dtype=torch.int, + ) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + tokens, _ = x.size() + expert_capacity = self.expert_capacity(tokens) + if expert_capacity == 0: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + if self.args.mlp_impl == 'grouped': + # GroupedMLP requires counts on CPU. We can use the tensor already + # moved to CPU for the prior all_to_all, which avoids an extra + # device synchronization. + parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( + dim=0, + dtype=torch.int, + ) + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all( + parallel_x, + send_counts, + recv_counts, + self.args.expert_parallel_group, + ) + + # Reduce along the hidden sharding to get the final outputs. + # + # TODO(tgale): Fuse this into the following local permutation. 
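+ # (Annotation, not part of the original code.) When hidden_sharding_degree
+ # is h > 1, every routed token was replicated h times earlier and each
+ # replica carries a partial result from a different shard of the expert
+ # weights; the view(h, -1, hidden_size) plus the sum over dim 0 below adds
+ # those h partial results back together.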
+ shape = ( + mpu.hidden_sharding_degree(self.args), + -1, + self.args.hidden_size, + ) + x = ops.sum(x.view(shape), dim=0) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + return x, tokens_per_expert.flatten() + + def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + in_shape = x.size() + + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) + if self.training and self.args.moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, scores)) + x = x.view(in_shape) + if self.bias is not None: + if self.args.return_bias: + return x, self.bias + return x + self.bias + return x + + +class MoE(torch.nn.Module): + + def __init__(self, args: Arguments): + super(MoE, self).__init__() + + # Token router. + self.router = router.LearnedRouter(args) + + # Expert computation helper. + self.experts = self._init_experts_mlp(args) + + self.shared_expert = None + if args.shared_expert: + # SharedExpert computation helper. + self.shared_expert = sharedexpert_registry.get(args) + + def _init_experts_mlp(self, args: Arguments): + return ParallelMLP(args) + + def forward(self, x: torch.Tensor): + # NOTE: If we're going to cast the activations to lower precision + # do it before we permute the tokens to save bandwidth. + x = common.cast_if_autocast_enabled(x) + + # Compute the expert scores and assignments. + scores, expert_weights, top_experts = self.router(x) + + # Compute the experts. + out = self.experts(x, scores, expert_weights, top_experts) + if self.shared_expert is not None: + shared_expert_out = self.shared_expert(x) + out = self.shared_expert.add_experts_sharedexpert( + shared_expert_out, + out, + ) + return out diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py new file mode 100644 index 0000000000000000000000000000000000000000..b23213902f4567d7fdb0158cbcf5406c2b2aa601 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py @@ -0,0 +1,93 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.distributed as dist + +from megablocks.layers.arguments import Arguments + + +class MoeParam(torch.Tensor): + + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool + + +def is_moe_param(tensor: torch.Tensor) -> bool: + return hasattr(tensor, 'expert_model_parallel') + + +def get_expert_parallel_world_size(args: Arguments) -> int: + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) + + +def get_expert_parallel_rank(args: Arguments) -> int: + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) + + +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, 'expert_model_parallel') + setattr(tensor, 'expert_model_parallel', is_parallel) + + +def param_is_expert_model_parallel(param: MoeParam) -> bool: + return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) + + +def copy_expert_model_parallel_attributes( + destination_tensor: torch.Tensor, + source_tensor: torch.Tensor, +): + if hasattr(source_tensor, 'expert_model_parallel'): + setattr( + destination_tensor, + 'expert_model_parallel', + getattr(source_tensor, 
'expert_model_parallel'), + ) + + +def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + for i in range(world_size): + dist.barrier(group) + if i == rank: + print(f'rank = {rank}', *x) + + +# Helpers for expert/tensor sharding. +def expert_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = min(world_size, args.moe_num_experts) + + if (args.moe_num_experts % esd) != 0: + raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) + return esd + + +def hidden_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = expert_sharding_degree(args) + hsd = world_size // esd + + if (args.ffn_hidden_size % hsd) != 0: + raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", + ) + return hsd + + +def experts_per_rank(args: Arguments) -> int: + return args.moe_num_experts // expert_sharding_degree(args) + + +def features_per_rank(args: Arguments) -> int: + return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/router.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/router.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9dcd9e433322482a80d5b95afccee5c12368f8 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/router.py @@ -0,0 +1,114 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch + +from megablocks.layers import common +from megablocks.layers.arguments import Arguments + +_ROUTER_LOGITS = [] + + +def _save_router_logits(logits: torch.Tensor, args: Arguments): + if args.moe_zloss_weight == 0: + return + global _ROUTER_LOGITS + _ROUTER_LOGITS.append(logits) + + +def clear_router_zloss(): + global _ROUTER_LOGITS + _ROUTER_LOGITS.clear() + + +def batched_router_zloss(args: Arguments): + global _ROUTER_LOGITS + + if args.moe_zloss_weight == 0: + import warnings + warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') + return 0 + + logits_per_router = _ROUTER_LOGITS + + if args.moe_zloss_in_fp32: + logits_per_router = [logits.float() for logits in logits_per_router] + + unscaled_zloss_per_router = torch.stack([ + torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router + ]) + + return args.moe_zloss_weight * unscaled_zloss_per_router + + +# NOTE: To enable end-to-end benchmarking without convergence we +# support a flag to force the router to assign tokens uniformly +# across the experts. We do this with a custom autograd operation +# so that PyTorch still executes the full set of router operation. +class _UniformExpertAssignment(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, num_experts: int): + out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) + out = torch.remainder(out, num_experts) + return out.view(x.shape) + + +_uniform_expert_assignment = _UniformExpertAssignment.apply + + +class LearnedRouter(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + + # Learned router parameters. + # + # NOTE: This weight matrix is not parallelized with expert model + # parallelism. 
Each device needs the entire router weight matrix + # so that it can route its batch of data correctly. + self.layer = torch.nn.Linear( + args.hidden_size, + args.moe_num_experts, + bias=False, + dtype=common.dtype(args), + device=args.device, + ) + args.init_method(self.layer.weight) + + def jitter(self, x: torch.Tensor): + low: float = 1.0 - self.args.moe_jitter_eps + high: float = 1.0 + self.args.moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return low + noise * (high - low) + + def _top_k(self, scores: torch.Tensor): + if self.args.moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, self.args.moe_top_k, dim=-1) + + def forward(self, x: torch.Tensor): + if self.training and self.args.moe_jitter_eps is not None: + x = x * self.jitter(x) + + logits = self.layer(x.view(-1, x.shape[-1])) + _save_router_logits(logits, self.args) + scores = logits.softmax(dim=-1) + expert_weights, expert_indices = self._top_k(scores) + if self.args.moe_normalize_expert_weights: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + + expert_indices = ( + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) if self.args.uniform_expert_assignment else expert_indices + ) + return scores, expert_weights, expert_indices diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..0f62db39a2dda5642f6baa9d9b949f7c31cf6d35 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -0,0 +1,30 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +_REGISTRY = { + 'mlp': mlp.SharedMLP, + 'glu': glu.SharedGLU, +} + + +def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: + """Returns an SharedMLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + SharedMLP instance. + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated SharedMLP constructed using the input args. 
+ """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + return _REGISTRY[args.mlp_type](args) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from megablocks.ops.binned_gather import binned_gather +from megablocks.ops.binned_scatter import binned_scatter +from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum +from megablocks.ops.gather import gather +from megablocks.ops.histogram import histogram +from megablocks.ops.padded_gather import padded_gather +from megablocks.ops.padded_scatter import padded_scatter +from megablocks.ops.repeat import repeat +from megablocks.ops.replicate import replicate +from megablocks.ops.round_up import round_up +from megablocks.ops.scatter import scatter +from megablocks.ops.sort import sort +from megablocks.ops.sum import sum +from megablocks.ops.topology import topology + +__all__ = [ + 'binned_gather', + 'binned_scatter', + 'exclusive_cumsum', + 'inclusive_cumsum', + 'gather', + 'histogram', + 'padded_gather', + 'padded_scatter', + 'repeat', + 'replicate', + 'round_up', + 'scatter', + 'sort', + 'sum', + 'topology', +] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..47b95301dbb35b0c3154d1ff2878d20792ce5cb2 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -0,0 +1,60 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + +from megablocks import benchmark_util +from megablocks.layers.all_to_all import all_to_all + +_ALL_TO_ALL_BENCHMARK = ( + (8, 1024), + (16, 1024), + (32, 1024), + (64, 1024), + (128, 1024), + (256, 1024), + (512, 1024), + (1024, 1024), + (2 * 1024, 1024), + (4 * 1024, 1024), + (8 * 1024, 1024), + (16 * 1024, 1024), + (32 * 1024, 1024), + (64 * 1024, 1024), + (128 * 1024, 1024), + (256 * 1024, 1024), + (512 * 1024, 1024), + (1024 * 1024, 1024), +) + + +def benchmark_all_to_all(group, sl, hs): + world_size = dist.get_world_size(group) + assert (sl % world_size) == 0 + send_recv_sizes = [sl // world_size] * world_size + + x = torch.randn((sl, hs)).cuda().half() + + details = { + 'world_size': world_size, + 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. 
+ } + + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) + + time, std = benchmark_util.benchmark_function(benchmark) + + if dist.get_rank(group) == 0: + benchmark_util.log_benchmark('All-To-All', details, time, std) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _ALL_TO_ALL_BENCHMARK: + benchmark_all_to_all(group, *args) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..89cce1b627137d7f987baca62fd6d82c5c04659a --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py @@ -0,0 +1,37 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_gather kernel. +class BinnedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): + ctx.save_for_backward(indices, bins) + ctx.top_k = top_k + return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + indices, bins = ctx.saved_tensors + out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) + return out, None, None, None, None + + +binned_gather = BinnedGatherOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ce0d6f5a890a58b89e6a501216a9822341323f --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py @@ -0,0 +1,59 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_scatter kernel. +class BinnedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + assert len(x.size()) == 3 + ctx.bin_size = x.size(1) + ctx.top_k = top_k + + # TODO(tgale): Don't save 'x' for backwards if we don't need to + # calculate the gradient w.r.t. 'weights'. 
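+ # (Annotation, not part of the original code.) 'x' is saved only so the
+ # backward pass can form the gradient w.r.t. 'weights' via
+ # binned_scatter_wgrad; the input gradient itself is recovered from 'grad'
+ # alone with binned_gather.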
+ ctx.save_for_backward(x, indices, weights, bins) + return kernels.binned_scatter(x, indices, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + x, indices, weights, bins = ctx.saved_tensors + out = kernels.binned_gather( + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[2]: + wgrad = kernels.binned_scatter_wgrad( + x, + grad, + indices, + bins, + ctx.top_k, + ) + return out, None, wgrad, None, None + + +binned_scatter = BinnedScatterOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7c5137c900d66a959a76fa681436362a4df906 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py @@ -0,0 +1,52 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + # import megablocks_ops as ops # type: ignore + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrappers for cumsum kernels. +# NOTE: Does not support gradients. +class ExclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int): + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.exclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.exclusive_cumsum(x, dim, out) + return out + + +exclusive_cumsum = ExclusiveCumsumOp.apply + + +class InclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.inclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.inclusive_cumsum(x, dim, out) + return out + + +inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py new file mode 100644 index 0000000000000000000000000000000000000000..41b09a1233e8a996ff1062579cd0810d095ad1e6 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for gather kernel. 
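+ # (Illustrative usage, not part of the original code.) Given the routing
+ # metadata produced by ops.sort / ops.histogram / ops.inclusive_cumsum,
+ #   permuted = gather(x, indices, bin_ids, bins, top_k)
+ # reorders the rows of x so that all tokens assigned to the same expert
+ # are stored contiguously; scatter applies the inverse permutation.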
+class GatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins) + ctx.top_k = top_k + return kernels.gather(x, indices, bin_ids, None, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins = ctx.saved_tensors + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) + return out, None, None, None, None, None + + +gather = GatherOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py new file mode 100644 index 0000000000000000000000000000000000000000..77e06a29163aa335dc61e8c50932d5f83322ae1d --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py @@ -0,0 +1,27 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for histogram kernel. +# NOTE: Does not support gradients. +class HistogramOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, max_val: float): + return ops.histogram(x, max_val) + + +histogram = HistogramOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..9de8e65281c829ba63b5c60853dfa5bb0a333988 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -0,0 +1,78 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_HISTOGRAM_TESTS = ( + (16384, torch.int32, 2), + (16384, torch.int32, 4), + (16384, torch.int32, 8), + (16384, torch.int32, 16), + (16384, torch.int32, 32), + (16384, torch.int32, 64), + (16384, torch.int32, 128), + (16384, torch.int32, 256), +) + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. 
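+ # (Annotation, not part of the original code.) The first call typically
+ # pays one-time costs such as kernel compilation/caching and memory
+ # allocator warm-up, so it is excluded from the timed iterations.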
+ fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class HistogramBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testHistogram(self, n, dtype, max_val): + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testTorchHistogram(self, n, dtype, max_val): + x = torch.randint(0, 128, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..bfa7b7c2f697f9454c5381e31ee78040be5b229e --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -0,0 +1,403 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import stk +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + + +# Calling tensor.t() calls tensor.transpose(0, 1) which calls +# torch.as_strided(...). Circumvent this chain to avoid an overhead +# this adds. +def transpose_view(x): + return torch.as_strided( + x, + (x.shape[1], x.shape[0]), + (x.stride()[1], x.stride()[0]), + ) + + +_MATMUL_TESTS = ( + (64 * 1024, 512, 2048, 64), + (32 * 1024, 768, 3072, 64), + (8 * 1024, 1024, 4096, 64), + (4 * 2048, 4096, 4 * 4096, 4), +) + + +def log_benchmark(name, arguments, time, std, flops): + benchmark_util.log_benchmark(name, arguments, time, std) + print('flops = {:.2f}B'.format(flops / 1e9)) + print('throughput = {:.2f}T'.format(flops / 1e9 / time)) + print('=' * 60) + + +class MatmulBenchmark(parameterized.TestCase): + + def build_sparse_matrix(self, x, padded_bins, fhs, ne): + blocking = 128 + padded_tokens, _ = x.size() + assert padded_tokens % blocking == 0 + assert fhs % blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // blocking + blocks_per_row = fhs // blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. 
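+ # (Annotation, not part of the original code.) ops.topology builds the
+ # column index of each nonzero (blocking x blocking) block from the padded
+ # bin boundaries, so the block-sparse pattern follows how many tokens were
+ # routed to each expert.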
+ column_indices = ops.topology( + padded_bins, + blocking, + block_rows, + blocks_per_row, + ) + data = torch.empty( + column_indices.numel(), + blocking, + blocking, + dtype=torch.float16, + device=x.device, + ) + shape = (padded_tokens, fhs * ne) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix(shape, data, row_indices, column_indices, offsets) + + def build_input_matrix(self, sl, hs, ne): + x = torch.randn((sl, hs)).cuda().half() + + # Assign tokens to experts uniformly. + top_expert = torch.arange(0, sl).cuda().int() % ne + + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) + return out, padded_bins + + def build_weight_matrix(self, ne, hs, fhs): + return torch.randn((hs, ne * fhs)).cuda().half() + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(x, w, topo) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd::SDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(topo, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradX::DSD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + topo = topo.t() + + def benchmark(): + return stk.ops.dsd(topo, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DSD::NN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): 
+ x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(out, w, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::SDD::NT', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + x = x.t() + + def benchmark(): + return stk.ops.dsd(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + + w = w.transpose(1, 2).contiguous() + w = w.transpose(1, 2) + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd:DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + w = w.transpose(1, 2).contiguous() + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradX:DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + out = out.transpose(1, 2) + + def benchmark(): + return torch.bmm(out, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradW:DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) 
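+ # (Annotation, not part of the original code.) The flops values passed to
+ # log_benchmark follow the usual 2*M*K*N count for a matmul (one multiply
+ # and one add per accumulated term), e.g. x.numel() * hs * 2 above for the
+ # second FFN linear.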
+ + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + w = torch.transpose(w, 1, 2) + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + x = torch.transpose(x, 1, 2) + + def benchmark(): + return torch.bmm(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..f272a7768dc6468e81bd2cd25294ca16a6826c08 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py @@ -0,0 +1,55 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_gather kernel. +class PaddedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins, padded_bins) + ctx.top_k = top_k + return kernels.padded_gather( + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins, padded_bins = ctx.saved_tensors + out = kernels.padded_scatter( + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) + return out, None, None, None, None, None + + +padded_gather = PaddedGatherOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff81dd9456a04258346266e9e85871bda56c65b --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py @@ -0,0 +1,98 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_scatter kernel. 
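+ # (Annotation, not part of the original code.) The padded variants lay
+ # tokens out with per-expert padding up to a block-size multiple (the
+ # padded_bins argument), whereas binned_gather/binned_scatter use a fixed
+ # expert capacity per bin and drop tokens that exceed it.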
+class PaddedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward( + indices, + bin_ids, + weights, + bins, + padded_bins, + *maybe_x, + ) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.padded_gather( + grad, + indices, + bin_ids, + weights, + bins, + padded_bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.padded_scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + padded_bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None, None + + +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): + return PaddedScatterOp.apply( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..81dde4e4b4c0f550c6027390e8de285fa4d842f7 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -0,0 +1,66 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PADDED_SCATTER_BENCHMARK = ( + # dMoE-Medium, 8-way EMP. + (1024 * 16, 1024, 8, 4), + # dMoE-Medium, post-all-to-all. + (1024 * 16 * 4, 1024, 8, 1), +) + + +class PaddedScatterTest(parameterized.TestCase): + + @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) + def testPaddedScatter(self, sl, hs, ne, top_k): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + # Sample weights for the scatter reduce. + weights = torch.rand((sl * top_k,)).cuda().half() + + # Gather the data to prepare for backwards. 
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + def benchmark(): + return ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + time, std = benchmark_util.benchmark_function(benchmark) + benchmark_util.log_benchmark( + 'Padded Scatter', + { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + 'top_k': top_k, + }, + time, + std, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..837f07e2858d5c65c14e3f262617e7d2985ce901 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -0,0 +1,149 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PERMUTE_TESTS = ( + (16384, 768, 2), + (16384, 768, 4), + (16384, 768, 8), + (16384, 768, 16), + (16384, 768, 32), + (16384, 768, 64), + (16384, 768, 128), + (16384 * 8, 768, 2), + (16384 * 8, 768, 4), + (16384 * 8, 768, 8), + (16384 * 8, 768, 16), + (16384 * 8, 768, 32), + (16384 * 8, 768, 64), + (16384 * 8, 768, 128), +) + + +class PermuteBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedGather(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.binned_gather(x, indices, bins, ec) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedScatter(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.binned_gather(x, indices, bins, ec) + + def benchmark(): + return ops.binned_scatter(x, indices, bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedGather(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. 
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedScatter(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testCopy(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + # ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + y = x.clone() + + def benchmark(): + return y.copy_(x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/repeat.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/repeat.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def repeat(x: torch.Tensor, tiling: torch.Size): + if all((t == 1 for t in tiling)): + return x + return x.repeat(*tiling) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..e570e22093339d1b71dc830edf656897c67f40e7 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py @@ -0,0 +1,36 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. 
+try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for replicate kernel. +class ReplicateOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): + ctx.save_for_backward(bins) + out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) + ops.replicate_forward(x, bins, out) + return out + + @staticmethod + def backward(ctx: Any, grad: torch.Tensor): + bins, = ctx.saved_tensors + out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) + ops.replicate_backward(grad, bins, out) + return out, None, None + + +replicate = ReplicateOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/round_up.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/round_up.py new file mode 100644 index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/round_up.py @@ -0,0 +1,14 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def round_up(x: torch.Tensor, value: int): + assert isinstance(value, int) + assert x.dtype == torch.int32 + + # TODO(tgale): If this becomes and issue + # do this in a custom kernel. We only expect + # to use this on arrays of less than 1k elements. + return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..a5aaafc402a651a47595c25dbf71e382903e5022 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py @@ -0,0 +1,72 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for scatter kernel. 
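+ # (Illustrative usage, not part of the original code.)
+ #   out = scatter(x, indices, bin_ids, weights, bins, top_k)
+ # un-permutes expert outputs back to token order and, when top_k > 1,
+ # combines each token's top_k expert outputs weighted by 'weights'.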
+class ScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins = saved_tensors[:4] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.gather( + grad, + indices, + bin_ids, + weights, + bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None + + +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +) -> Optional[torch.Tensor]: + return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py new file mode 100644 index 0000000000000000000000000000000000000000..a64a4553323d8de71a1d5f1821bac50249c68197 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional, Tuple + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + +_BITS_FOR_DTYPE = { + torch.int16: 16, + torch.int32: 32, + torch.int64: 64, +} + + +# Autograd wrapper for sort kernel. +# NOTE: Does not support gradients. 
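+ # (Illustrative usage, not part of the original code.)
+ #   bin_ids, indices = sort(top_expert, end_bit)
+ # returns the sorted expert ids together with the argsort permutation.
+ # When end_bit is omitted it defaults to the full bit width of the dtype;
+ # a smaller value such as ceil(log2(num_experts)) lets the radix sort skip
+ # the high-order passes.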
+class SortOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + if end_bit is None: + end_bit = _BITS_FOR_DTYPE[x.dtype] + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + ops.sort(x, end_bit, x_out, iota_out) + return (x_out, iota_out) + + +sort = SortOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -0,0 +1,85 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_SORT_TESTS = ( + (16384, torch.int32, None), + (16384, torch.int32, 2), + (16384, torch.int32, 128), +) + +_BASELINE_SORT_TESTS = ((16384,),) + + +def numpy_dtype(dtype): + types = { + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + } + return types[dtype] + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. + fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class SortBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_SORT_TESTS) + def testSort(self, n, dtype, max_val): + if max_val is None: + max_val = np.iinfo(numpy_dtype(dtype)).max + end_bit = int(np.ceil(np.log2(max_val))) + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_BASELINE_SORT_TESTS) + def testTorchSort(self, n): + x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) + arguments = { + 'n': n, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sum.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sum.py new file mode 100644 index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sum.py @@ -0,0 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import torch + + +def sum(x: torch.Tensor, dim: int = 0): + if x.shape[dim] == 1: + return x.squeeze(dim=dim) + return x.sum(dim=dim) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py new file mode 100644 index 
0000000000000000000000000000000000000000..b4a693a82a36c052e51afb4a02077e7849a4abf1 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py @@ -0,0 +1,45 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for topology kernel. +# NOTE: Does not support gradients. +class TopologyOp(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, + ): + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + ops.indices( + padded_bins, + block_size, + output_block_rows, + output_block_columns, + out, + ) + return out + + +topology = TopologyOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d8a2495c4baab89e5f90d84fa92f406a29ea10 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py @@ -0,0 +1,195 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from ._ops import ops + +from megablocks.layers.arguments import Arguments +from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE +from megablocks.layers.glu import SparseGLU +from megablocks.layers.mlp import MLP, SparseMLP +from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +# This section contains the direct kernel exports (not inlcuded in the original code) +def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute exclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.exclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute inclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.inclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: + """ + Compute histogram of input tensor values. + + Args: + x: Input tensor + num_bins: Number of histogram bins + + Returns: + Histogram tensor with counts for each bin + """ + return ops.histogram(x, num_bins) + + +def indices( + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, +) -> torch.Tensor: + """ + Construct indices from padded bins for sparse operations. 
+ + Args: + padded_bins: Tensor containing bin boundaries + block_size: Size of each block + output_block_rows: Number of rows in output blocks + output_block_columns: Number of columns in output blocks + + Returns: + Tensor containing constructed indices + """ + return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) + + +def replicate_forward( + x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Forward pass of replicate operation - replicate values according to bin sizes. + + Args: + x: Input tensor with values to replicate + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_forward(x, bins, out) + + +def replicate_backward( + grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Backward pass of replicate operation - reduce gradients back to bins. + + Args: + grad: Gradient tensor to reduce + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_backward(grad, bins, out) + + +def sort( + x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor +) -> torch.Tensor: + """ + Radix sort with index tracking. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + x_out: Output tensor for sorted values + iota_out: Output tensor for sorted indices + + Returns: + The sorted values tensor + """ + return ops.sort(x, end_bit, x_out, iota_out) + + +# Convenience functions for common use cases +def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: + """ + Compute cumulative sum with automatic output allocation. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum (default: last dimension) + exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum + + Returns: + New tensor containing the cumulative sum + """ + out = torch.empty_like(x) + if exclusive: + return exclusive_cumsum(x, dim, out) + else: + return inclusive_cumsum(x, dim, out) + + +def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sort tensor and return both sorted values and indices. 
+ + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + + Returns: + Tuple of (sorted_values, sorted_indices) + """ + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + sort(x, end_bit, x_out, iota_out) + return x_out, iota_out + + +# Export public API +__all__ = [ + # Direct kernel exports + "exclusive_cumsum", + "inclusive_cumsum", + "histogram", + "indices", + "replicate_forward", + "replicate_backward", + "sort", + "cumsum", + "argsort", + # Original exports + "Arguments", + "ParallelDroplessMLP", + "dMoE", + "SparseGLU", + "MLP", + "SparseMLP", + "MoE", + "ParallelMLP", + "get_load_balancing_loss", +] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_359242d.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..f58a6ba85d95e023464dafbb7d4f75d60bf639f1 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_359242d.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:122a39925ccc50597aac33bcd4eec4d2d3ff29846c3a79671da2c554d4967777 +size 17748280 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd42432c910375fb1713b3bf29145e48b556498 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _megablocks_359242d +ops = torch.ops._megablocks_359242d + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_megablocks_359242d::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_version.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_version.py new file mode 100644 index 0000000000000000000000000000000000000000..c55783177af19bc03654c730c4892df8f8532279 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_version.py @@ -0,0 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +"""The MegaBlocks Version.""" + +__version__ = '0.11.0.dev0' diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..b584ceede926ca30abef2dec581cb3ff329e8e16 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py @@ -0,0 +1,543 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') + + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') + + +def assert_equal(a, b): + if a != b: + raise ValueError(f'Expected 
dimensions to be equal but got {a} and {b}.',) + + +# a: (tokens, hidden_size), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy( + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Our index into array 'a'. + index_a = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'b'. + index_b = offset_in_bin + if bin_idx > 0: + index_b += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: Because of the padding, the output size is dynamic. + # We load the final padded bin bound to get the output rows. + output_rows = padded_bins[-1].cpu().item() + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def gather(x, indices, bin_ids, weights, bins, top_k): + # Validate the input shapes. 
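+    # gather() is the unpadded variant of padded_gather(): it reuses the
+    # _padded_copy kernel but passes `bins` as both the real and padded bin
+    # boundaries, so no rows are reserved for expert padding and the output
+    # is exactly [tokens * top_k, hidden_size].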
+ assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: There is no padding so the output rows equals the + # input rows multiplied by top_k. + output_rows = x.shape[0] * top_k + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + out, + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + + +def scatter(x, indices, bin_ids, weights, bins, top_k): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +# x: (tokens, top_k, hidden_size), real +# grad: (tokens, hidden_size), real. +# wgrad: (tokens, top_k), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy_wgrad( + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Our index into 'tokens * top_k'. + index_out = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'x'. + index_x = offset_in_bin + if bin_idx > 0: + index_x += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. 
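+    # wgrad holds one scalar per routed (token, top_k) entry, grad is indexed
+    # per token (hence the // TOP_K below), and x by its padded bin position.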
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + _padded_copy_wgrad[(indices.shape[0],)]( + x, + grad, + out, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + TOP_K=top_k, + ) + return out + + +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy( + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_b = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_a = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + # + # NOTE: We need to zero the output in both directions. 
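+    # Slots that are never written (padding rows of 'b', or dropped entries of
+    # 'a' when a bin exceeds expert_capacity) keep their zero initialisation,
+    # which is why binned_gather/binned_scatter allocate with torch.zeros.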
+ iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + num_experts = bins.shape[0] + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + + _binned_copy[(num_experts, expert_capacity)]( + x, + out, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def binned_scatter(x, indices, weights, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( + out, + x, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=hidden_size, + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy_wgrad( + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_x = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_out = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. 
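+    # Mirrors _padded_copy_wgrad: wgrad gets one scalar per routed (token, k)
+    # entry, grad is indexed per token, and x by its (expert, capacity) slot.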
+ wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def binned_scatter_wgrad(x, grad, indices, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) + _binned_copy_wgrad[(num_experts, expert_capacity)]( + x, + grad, + out, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS=hidden_size, + TOP_K=top_k, + ) + return out diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/bak.__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/bak.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5217959caf74527e3bf7f80db6f93be21c016963 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/bak.__init__.py @@ -0,0 +1,23 @@ +from megablocks_moe.megablocks import ( + MoE, + dMoE, + get_load_balancing_loss, + ParallelMLP, + ParallelDroplessMLP, + SparseMLP, + MLP, + SparseGLU, + Arguments, +) + +__all__ = [ + "MoE", + "dMoE", + "get_load_balancing_loss", + "ParallelMLP", + "ParallelDroplessMLP", + "SparseMLP", + "MLP", + "SparseGLU", + "Arguments", +] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/benchmark_util.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/benchmark_util.py new file mode 100644 index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/benchmark_util.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + + +def log_benchmark(name, arguments, time, std): + print('=' * 60) + print(f'{name} Benchmark') + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) + print('=' * 60) + + +def benchmark_function(fn, iterations=100, warmup=10): + # Warmup iterations. 
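+    # Warm-up calls are excluded from timing so one-off costs (kernel
+    # autotuning, CUDA context setup, allocator growth) do not skew the
+    # CUDA-event measurements below.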
+ for _ in range(warmup): + fn() + + times = [] + for i in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + fn() + end.record() + + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + return np.mean(times), np.std(times) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6d3f977f360fd0ad5800c3b5da9ce57be794b9b8 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm_util.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import warnings + +_grouped_gemm_is_available: bool = False +try: + import grouped_gemm + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn('Grouped GEMM not available.') + + +def grouped_gemm_is_available(): + return _grouped_gemm_is_available + + +def assert_grouped_gemm_is_available(): + msg = ( + 'Grouped GEMM not available. Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg + + +backend = grouped_gemm.backend if grouped_gemm_is_available() else None +ops = grouped_gemm.ops if grouped_gemm_is_available() else None diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..849b023b1a0f765fb1e26addf4670dc6db785a52 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# from megablocks.layers.dmoe import dMoE +from megablocks.layers.moe import MoE + +__all__ = [ + 'MoE', + # 'dMoE', +] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/activation_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..a31770ba179fed06abf2da10102ccaeed1d3ee4e --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/activation_fn.py @@ -0,0 +1,33 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Union + +import torch +from stk import Matrix + + +def act_fn( + x: Matrix, + function: Callable, + return_grad_fn: bool = False, + **kwargs, +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) + with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): + if return_grad_fn: + x.data.requires_grad = True + out = function(x.data, **kwargs) + y = Matrix( + x.size(), + out, + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + if return_grad_fn: + return y, out.backward + return y diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/all_to_all.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/all_to_all.py @@ -0,0 +1,54 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + + +class 
AllToAllOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) + + ctx.input_shape = x.shape + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + ctx.group = group + handle = dist.all_to_all_single( + out, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + return out, handle + + @staticmethod + def backward(ctx, grad, _): + if ctx.needs_input_grad[0]: + out = torch.empty( + ctx.input_shape, + device=grad.device, + dtype=grad.dtype, + ) + dist.all_to_all_single( + out, + grad, + output_split_sizes=ctx.input_split_sizes, + input_split_sizes=ctx.output_split_sizes, + group=ctx.group, + ) + return out, None, None, None, None + return None, None, None, None, None + + +def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): + return AllToAllOp.apply( + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/arguments.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..3962c771c90012535aab058443f89c541a2e9236 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/arguments.py @@ -0,0 +1,100 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from functools import partial +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +import megablocks.grouped_gemm_util as grouped_gemm + +# Type annotation for in-place Tensor initialization function. +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] + +_ALLOWED_BITWIDTHS = (-1, 4, 8) + +DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') + + +@dataclasses.dataclass +class Arguments: + # Model arguments. + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + num_layers: int = 1 + bias: bool = True + return_bias: bool = True + activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN + + # MoE arguments. + moe_num_experts: int = 1 + moe_top_k: int = 1 + moe_capacity_factor: int = 1 + moe_normalize_expert_weights: Optional[Union[int, float]] = None + moe_loss_weight: float = 0.1 + moe_jitter_eps: Optional[float] = None + moe_lbl_in_fp32: bool = False + + # Parallelism arguments. + moe_expert_model_parallelism: bool = False + expert_parallel_group: Optional[dist.ProcessGroup] = None + pipeline_model_parallel_size: int = 1 + num_layers_per_virtual_pipeline_stage: Optional[int] = None + + # Compute arguments. + memory_optimized_mlp: bool = False + mlp_type: str = 'mlp' + mlp_impl: str = 'sparse' + + # Initialization arguments. + fp16: bool = True + bf16: bool = False + device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) + init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) + output_layer_init_method: InitFn = init_method + + # Benchmarking arguments. 
+ uniform_expert_assignment: bool = False + + # shared expert arguments + shared_expert: bool = False # enable using shared expert + fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) + fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers + remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored + shared_expert_hidden_size: Optional[ + int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size + shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) + + # Router Z-loss arguments + moe_zloss_weight: float = 0 # 1e-3 is a reasonable value + moe_zloss_in_fp32: bool = False + + def __post_init__(self): + # Sparse MLP is not supported with triton >=3.2.0 + # TODO: Remove this once sparse is supported with triton >=3.2.0 + if self.__getattribute__('mlp_impl') == 'sparse': + try: + import triton + if triton.__version__ >= '3.2.0': + raise ValueError( + 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.', + ) + except ImportError: + raise ImportError('Triton is required for sparse MLP implementation') + + if self.__getattribute__('mlp_impl') == 'grouped': + grouped_gemm.assert_grouped_gemm_is_available() + + if self.shared_expert_hidden_size is None: + self.shared_expert_hidden_size = self.ffn_hidden_size + + +def from_megatron(megatron_args: Any): + args = Arguments() + for field in dataclasses.fields(args): + if hasattr(megatron_args, field.name): + setattr(args, field.name, getattr(megatron_args, field.name)) + return args diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/common.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/common.py new file mode 100644 index 0000000000000000000000000000000000000000..ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/common.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from megablocks.layers.arguments import Arguments + + +def dtype(args: Arguments): + if args.fp16: + return torch.float16 + elif args.bf16: + return torch.bfloat16 + return None + + +def cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == 'cuda': + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == 'cpu': + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmlp_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..d765bd04387a29bddd2789fae04821635b555e82 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -0,0 +1,42 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +MlpType = Union[mlp.SparseMLP, glu.SparseGLU] + +_REGISTRY = { + 'mlp': { + 'grouped': mlp.GroupedMLP, + 'sparse': mlp.SparseMLP, + }, + 'glu': { + 'grouped': glu.GroupedGLU, + 'sparse': glu.SparseGLU, + }, +} + + +def get(args: Arguments) -> MlpType: + 
"""Returns an MLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + MLP instance. This only contains MLPs for use in dMoEs + (ie. only for the dropless versions of MoEs). + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated MLP constructed using the input args. + """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + if args.mlp_impl not in _REGISTRY[args.mlp_type]: + raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) + + return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmoe.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmoe.py new file mode 100644 index 0000000000000000000000000000000000000000..205727ff4d63f9e8dc9648acaac99a97f3394d6f --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmoe.py @@ -0,0 +1,327 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import stk.ops +import torch +from stk import Matrix + +import megablocks.ops as ops +# from megablocks.ops import ops +from megablocks.layers import common, dmlp_registry, moe, mpu +from megablocks.layers.arguments import Arguments + + +def promote_scalar(x): + return x.view(1) if not len(x.size()) else x + + +class ParallelDroplessMLP(moe.ParallelMLP): + + def __init__(self, args: Arguments): + super(ParallelDroplessMLP, self).__init__(args) + self.hidden_size = args.hidden_size + self.ffn_hidden_size = mpu.features_per_rank(args) + self.blocking = 128 + self.mlp = dmlp_registry.get(args) + + # Calculate the number of bits needed to represent the column indices + # in the intermediate sparse matrix. + max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) + self.transpose_sort_end_bit = max( + int(np.ceil(np.log2(max_column_index))), + 1, + ) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + # + # NOTE: Our sort operation uses the same width indices as the input values. + # To avoid overflow when we have large activation matrices we cast to + # 32-bit before sorting. + _, gather_indices = ops.sort( + column_indices.int(), + self.transpose_sort_end_bit, + ) + + # There are a constant number of blocks in every row of the sparse matrix. + # A blocks offset is: + # + # row_index * blocks_per_row + column_index % blocks_per_row + # + # Once we have the block offsets ordered for transposition we can divide + # by blocks_per_row to get the transposed column indices. + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + if nnz_per_column.dim() == 0: + # This addresses an edge case when ffn_hidden_size is equal to self.blocking. 
+ nnz_per_column = nnz_per_column.unsqueeze(0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + if self.ffn_hidden_size % self.blocking != 0: + raise ValueError( + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + + f'the block size {self.blocking}. Please update your configuration.', + ) + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // self.blocking + blocks_per_row = self.ffn_hidden_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology( + padded_bins, + self.blocking, + block_rows, + blocks_per_row, + ) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=common.dtype(self.args), + device='meta', + ) + shape = ( + padded_tokens, + self.ffn_hidden_size * mpu.experts_per_rank(self.args), + ) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, + row_indices, + column_indices, + offsets, + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_experts = top_experts.int() + bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Caculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = promote_scalar(bins) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def sparse_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather( + x, + indices, + bin_ids, + bins, + padded_bins, + self.top_k, + ) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. 
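+        # padded_scatter also applies the expert weights and reduces over the
+        # top_k copies, returning one row per original token.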
+ x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.top_k, + ) + return x, tokens_per_expert + + # For use in the base-class parallel_forward_once. + def sparse_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + return ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + top_k, + ) + + def grouped_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + out = self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + -1, # unused + self.args.moe_top_k, + ) + return out, tokens_per_expert + + def grouped_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Perform the expert computation. + x = self.mlp(x, tokens_per_expert) + + # Un-route the data for the MoE output. 
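+        # ops.scatter is the unpadded counterpart of padded_scatter: it applies
+        # the expert weights and sums over the top_k dimension.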
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + def forward_once(self, x, expert_weights, top_experts): + if self.args.mlp_impl == 'sparse': + return self.sparse_forward_once(x, expert_weights, top_experts) + else: + return self.grouped_forward_once(x, expert_weights, top_experts) + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ): + if self.args.mlp_impl == 'sparse': + return self.sparse_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + else: + return self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + + +class dMoE(moe.MoE): + + def _init_experts_mlp(self, args: Arguments): + return ParallelDroplessMLP(args) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/gelu.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..40b601d4a8eb59e80e1090f31f4172b8f7fb7549 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/gelu.py @@ -0,0 +1,43 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk +import torch +import torch.nn.functional as F + + +@torch.jit.script +def _gelu_backward_inplace(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) + return g.mul_(ff) + + +def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): + # NOTE: The two sparse matrices must have the same topology. + if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): + return stk.Matrix( + x.size(), + _gelu_backward_inplace(grad.data, x.data), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + return _gelu_backward_inplace(grad, x) + + +def gelu(x: stk.Matrix): + assert isinstance(x, stk.Matrix) + return stk.Matrix( + x.size(), + F.gelu(x.data, approximate='tanh'), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/glu.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/glu.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe0c915c307e7f7cade3ea3ff679399635fcd81 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/glu.py @@ -0,0 +1,223 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk.ops +import torch + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import Arguments +from megablocks.layers.mlp import ( + SharedMLP, + SparseMLP, + create_dmoe_expert_weights, + resolve_dtensor, +) + + +class SparseGLU(SparseMLP): + + def __init__(self, args: Arguments): + super().__init__(args) + self.v1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + with torch.no_grad(): + self.v1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + + mpu.set_expert_model_parallel_attributes( + 
self.v1, + self._should_set_parallelism_attribute, + ) + + def forward(self, x, topo): + if self.args.memory_optimized_mlp: + raise NotImplementedError( + 'Memory optimized implementation not yet supported with GLU with sparse kernels.', + ) + + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Compute the GLU. + x1 = stk.ops.sdd(x, w1.t(), topo) + x2 = stk.ops.sdd(x, v1.t(), topo) + + activation_fn_out = act_fn(x1, self.args.activation_fn) + x1 = stk.ops.mul(activation_fn_out, x2) + + return stk.ops.dsd(x1, w2) + + +class MemoryOptimizedGroupedGLU(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + v1 = v1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) + + # GeLU. + activation_fn_out = activation_fn(sdd_out) * v1_out + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, v1, w2 = saved_tensors[:3] + batch_sizes = saved_tensors[3] + x = saved_tensors[4] + sdd_out, v1_out = saved_tensors[5:7] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + v1_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * v1_out + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + dv1_out = v1_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dv1. 
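+        # dv1 mirrors dw1: the same saved input x, with the gating branch's
+        # gradient (dv1_out) in place of dsdd_out.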
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + dx = ddsd_out + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) + dx += gg.backend.gmm(dv1_out, v1, batch_sizes) + return dx, dw1, dv1, dw2, None, None + + +memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply + + +class GroupedGLU(SparseGLU): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, v1, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.v1), + self.scale_grad(self.w2), + ) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Re-shape the weights for the grouped GEMMs. + ne = mpu.experts_per_rank(self.args) + w1 = w1.view(ne, -1, self.args.hidden_size) + v1 = v1.view(ne, -1, self.args.hidden_size) + w2 = w2.view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_glu( + x, + w1, + v1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) + x1 = self.args.activation_fn(x1) * x2 + return gg.ops.gmm(x1, w2, batch_sizes) + + +class SharedGLU(SharedMLP): + """GPU for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class + """ + + def __init__(self, args: Arguments): + super().__init__(args) + self.gate_proj = args.fc_cls( + args.hidden_size, + self.args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/memory_test.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/memory_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4acbd94f212ce906ace9de78dd3ffc6afa03f97e --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/memory_test.py @@ -0,0 +1,102 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import gc + +import torch +import torch.distributed as dist + +from megablocks.layers import arguments, dmoe + +_TESTS = ((8, 2048, 4096, 4096, 32, 4),) + + +def get_tensors(): + ptrs = set() + out = [] + for obj in gc.get_objects(): + if torch.is_tensor(obj): + if not obj.is_contiguous() or obj.data_ptr() in ptrs: + continue + out.append(obj) + ptrs.add(obj.data_ptr()) + return out + + +def test_memory( + group, + batch_size, + sequence_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, +): + args = arguments.Arguments( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=num_experts, + moe_top_k=top_k, + moe_expert_model_parallelism=True, + expert_parallel_group=group, + fp16=False, + bf16=True, + device=torch.cuda.current_device(), + ) + layer = dmoe.dMoE(args).cuda() + + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16).requires_grad_(True) + torch.cuda.empty_cache() + + # Run forward + backward. + # with torch.autograd.detect_anomaly(): + out, _ = layer(x) + out.mean().backward() + + # Report peak memory. 
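+    # NOTE: max_memory_allocated()/max_memory_reserved() report peaks since
+    # program start (there is no reset_peak_memory_stats call), so they cover
+    # the forward and backward pass above.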
+ mem = torch.cuda.max_memory_allocated() + print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) + print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) + + # Calculate weight and gradient memory usage. + weight_memory = 2 * ( + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() + ) + + def grad_numel(x): + if x.grad is not None: + return x.grad.numel() + return 0 + + grad_memory = 2 * ( + grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + ) + weight_memory += grad_memory + + print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) + print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) + + # Manually calculate GPU memory usage from the garbage + # collector. + gc.collect() + total = 0 + tensors = get_tensors() + tensors = sorted(tensors, key=lambda x: -x.numel()) + for i, t in enumerate(tensors): + total += t.numel() + print(f'{i}: {t.shape}, {t.numel() * 2}') + del tensors + + print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _TESTS: + test_memory(group, *args) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mlp.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6f4d82441e3a9c185db5bfdf686d53790dde26 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mlp.py @@ -0,0 +1,574 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import stk +import stk.backend.triton_kernels +import stk.ops +import torch +from packaging import version + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, gelu, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + + +class ScaleGradient(torch.autograd.Function): + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx: Any, x: torch.Tensor, scale: float): + ctx.scale = scale + return x + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: torch.Tensor, grad: torch.Tensor): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +def resolve_dtensor(weight: torch.Tensor): + if version.parse(torch.__version__) >= version.parse('2.0.0'): + from torch.distributed._tensor import DTensor + if isinstance(weight, DTensor): + return weight.to_local() + return weight + + +def create_moe_expert_weights( + args: Arguments, + num_experts: int, + ffn_hidden_size: int, + hidden_size: int, + init_method: InitFn, +): + # Create the entire weight matrix such that the sampled weights will + # not vary between data parallelism and expert model parallelism for + # the same random seed. + master_weights = torch.empty( + num_experts, + ffn_hidden_size, + hidden_size, + device=args.device, + dtype=common.dtype(args), + ) + init_method(master_weights) + + if not args.moe_expert_model_parallelism: + return master_weights + + # Calculate the amount of sharding in each dimension. 
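+    # Experts are split across expert_sharding_degree groups of ranks and the
+    # ffn dimension across hidden_sharding_degree, so each rank slices out a
+    # contiguous [experts_per_rank, rows_per_rank, hidden_size] block below.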
+ expert_sharding_degree = mpu.expert_sharding_degree(args) + hidden_sharding_degree = mpu.hidden_sharding_degree(args) + + # Calculate the experts per rank. + # + # NOTE: We assign ranks to be expert parallel before going + # tensor parallel. + rank = mpu.get_expert_parallel_rank(args) + expert_rank = rank % expert_sharding_degree + num_experts_per_rank = num_experts // expert_sharding_degree + start_expert = expert_rank * num_experts_per_rank + end_expert = (expert_rank + 1) * num_experts_per_rank + + # Calculate the rows per rank. + row_rank = rank // expert_sharding_degree + num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree + start_row = row_rank * num_rows_per_rank + end_row = (row_rank + 1) * num_rows_per_rank + + # Slice the weight matrix to get the chunk for this rank. + with torch.no_grad(): + weights = master_weights[start_expert:end_expert, start_row:end_row] + return weights + + +class MLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) + experts_per_rank = mpu.experts_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + args.hidden_size, + mpu.features_per_rank(args), + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + mpu.features_per_rank(args), + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + mpu.set_expert_model_parallel_attributes( + self.w1, + args.moe_expert_model_parallelism, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + args.moe_expert_model_parallelism, + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. 
+ with torch.no_grad(): + w1 = create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) + self.w1.copy_(w1.transpose(1, 2).contiguous()) + self.w2.copy_( + create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + x = torch.bmm(x, w1) + x = self.args.activation_fn(x) + return torch.bmm(x, w2) + + +def create_dmoe_expert_weights( + args: Arguments, + num_experts: int, + rows: int, + columns: int, + init_method: InitFn, +): + weights = create_moe_expert_weights( + args, + num_experts, + rows, + columns, + init_method, + ) + return weights.view([-1, columns]) + + +class MemoryOptimizedMLP(torch.autograd.Function): + """Sparse MLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, topo, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + topo_tensors = ( + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) + + # Layer 0: x @ w1.t(). + sdd_out = stk.ops.sdd(x, w1.t(), topo) + + # GeLU. + activation_fn_out = act_fn(sdd_out, activation_fn) + + # Layer 1: x @ w2. + dsd_out = stk.ops.dsd(activation_fn_out, w2) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.shape = topo.shape + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.data.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + topo_tensors = saved_tensors[2:8] + x = saved_tensors[8] + sdd_out_data = saved_tensors[9] + + # rematerialize activation function output + activation_fn = ctx.activation_fn + sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) + activation_fn_out, activation_grad_fn = act_fn( + sdd_out, + activation_fn, + return_grad_fn=True, + ) + + # Compute dw2 with recomputed activation_fn output. + dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. 
+ dactivation_fn_out = activation_fn_out + stk.backend.triton_kernels.sdd( + ddsd_out, + w2.t(), + dactivation_fn_out.shape, + dactivation_fn_out.data, + dactivation_fn_out.offsets, + dactivation_fn_out.row_indices, + dactivation_fn_out.column_indices, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out.data) + dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) + + # Compute dw1. + dw1 = stk.ops.dsd(dsdd_out.t(), x) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + stk.backend.triton_kernels.dsd( + dsdd_out.shape, + dsdd_out.data, + dsdd_out.offsets, + dsdd_out.row_indices, + dsdd_out.column_indices, + dsdd_out.offsets_t, + dsdd_out.column_indices_t, + dsdd_out.block_offsets_t, + False, + w1, + ddsd_out, + ) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_mlp = MemoryOptimizedMLP.apply + + +class SparseMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. + with torch.no_grad(): + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism + mpu.set_expert_model_parallel_attributes( + self.w1, + self._should_set_parallelism_attribute, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + self._should_set_parallelism_attribute, + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x, topo): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + if self.args.memory_optimized_mlp: + return memory_optimized_mlp( + x, + w1, + w2, + topo, + self.args.activation_fn, + ) + + # Compute the MLP. 
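+        #
+        # The steps below are: a blocked-sparse x @ w1.t() (SDD) restricted
+        # to the nonzero blocks described by `topo`, the activation applied
+        # to the sparse result, and a DSD product with w2 that returns a
+        # dense [tokens, hidden_size] output.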
+ x = stk.ops.sdd(x, w1.t(), topo) + activation_fn_out = act_fn(x, self.args.activation_fn) + return stk.ops.dsd(activation_fn_out, w2) + + +class MemoryOptimizedGroupedMLP(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + + # activation_fn + activation_fn_out = activation_fn(sdd_out) + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: Any, ddsd_out: torch.Tensor): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + batch_sizes = saved_tensors[2] + x = saved_tensors[3] + sdd_out = saved_tensors[4] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply + + +class GroupedMLP(SparseMLP): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) + + # Re-shape the weights for the grouped GEMMs. 
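+        # w1/w2 are stored flattened as
+        # [experts_per_rank * features_per_rank, hidden_size]; the views
+        # below expose the per-expert leading dimension expected by the
+        # grouped GEMM: [experts_per_rank, features_per_rank, hidden_size].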
+ ne = mpu.experts_per_rank(self.args) + w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) + w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_mlp( + x, + w1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x = self.args.activation_fn(x) + return gg.ops.gmm(x, w2, batch_sizes) + + +class SharedMLP(torch.nn.Module): + """MLP for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class + """ + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self.fc_kwargs: dict[str, Any] = { + 'bias': args.bias, + 'device': args.device, + } + self.fc_kwargs.update(args.fc_kwargs) + + self.up_proj = args.fc_cls( + args.hidden_size, + args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + self.act = args.activation_fn + self.down_proj = args.fc_cls( + args.shared_expert_hidden_size, + args.hidden_size, + **self.fc_kwargs, + ) + self.down_proj._is_residual = True # a flag for llm-foundry init + + def add_experts_sharedexpert( + self, + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + ) -> torch.Tensor: + # Helper function to add expert output to shared expert output + # with optional weighted sum. + if self.args.shared_expert_weighted_sum: + # enable using weighted sum for shared expert output + # wieghted by number of experts used + t_experts = self.args.moe_top_k + 1 + sh_mlp_out = shared_expert_out / t_experts + return sh_mlp_out.add( + expert_out, + alpha=(self.args.moe_top_k / t_experts), + ) + + return shared_expert_out + expert_out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/moe.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/moe.py @@ -0,0 +1,475 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist + +import megablocks.ops as ops +from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +from megablocks.layers.all_to_all import all_to_all +from megablocks.layers.arguments import Arguments + +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss(args: Arguments): + if args.moe_loss_weight == 0: + return 0.0 + + # tokens_per_expert[i].shape = (num_experts) + # expert_scores[i].shape = (tokens, num_experts) + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} token_per_experts ' 
+ f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} expert_scores ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) + + tokens = expert_scores[0].shape[0] + assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = (args.moe_num_experts * args.moe_loss_weight) + scale_denominator = (args.num_layers * tokens * args.moe_top_k) + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +# NOTE: This class defines MoE expert computation, including expert model parallel +# communication. When using FSDP on top of MegaBlocks this is the module that should +# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model +# parallel all2all. +class ParallelMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super(ParallelMLP, self).__init__() + self.args = args + + # Calculate the number of experts in total and the number of experts + # owned by this rank. + # world_size = mpu.get_expert_parallel_world_size(args) + self.num_experts = args.moe_num_experts + self.top_k = self.args.moe_top_k + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + # Expert MLP. + self.mlp = mlp.MLP(args) + + self.bias: Optional[torch.Tensor] + if self.args.bias: + # Note that the output bias is not parallelized with expert + # model parallelism. + self.bias = torch.nn.Parameter( + torch.empty( + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + # Select the forward function for the operating mode. 
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) + + def expert_capacity(self, tokens: int) -> int: + world_size = mpu.get_expert_parallel_world_size(self.args) + tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) + return int(self.args.moe_capacity_factor * tokens_per_expert) + + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): + """Calculate the load balancing loss contribution.""" + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == self.num_experts + assert len(tokens_per_expert.size()) == 1 + num_experts, = tokens_per_expert.size() + assert num_experts == self.num_experts + scale = self.num_experts / (tokens * self.top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # + # TODO(tgale): Is it worth doing this conversion to 32-bit + # prior? Could we place the `torch.max` operation to return + # 32-bit expert indices? + top_expert = top_expert.int() + output = ops.sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + # + # TODO(tgale): Does the sorted data produce a more favorable + # data distribution for histogram? Or is the op parallelism + # worth more? + tokens_per_expert = ops.histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None + bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + + return indices, bin_ids, bins, tokens_per_expert + + def permute_and_compute( + self, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, + ): + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output + + # Perform the expert computation. Note that we don't + # use biases for these linear operations. + x = self.mlp(x) + + # Un-route the data for the MoE output. + return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. 
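+        #
+        # Illustrative example with hypothetical settings: sl * bs = 4096
+        # tokens, top_k = 2, an expert-parallel world size of 1, 64 experts
+        # and moe_capacity_factor = 1.0 gives
+        # expert_capacity(4096) = int(1.0 * 2 * 4096 * 1 / 64) = 128 slots
+        # per expert.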
+ sl, bs, _ = x.size() + expert_capacity = self.expert_capacity(sl * bs) + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = self.permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + self.top_k, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignement. After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + # + # Compute the mapping of local tokens to experts. + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so every device gets the counts. + repeated_tokens_per_expert = ops.repeat( + tokens_per_expert, + (mpu.hidden_sharding_degree(self.args),), + ) + + # Pass token count information to the device on which the + # target expert resides. + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.args.expert_parallel_group, + async_op=True, + ) + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + # + # This view updates the shape of the tensor from [sl, bs, hs] to + # [sl * bs, hs] prior to the permutation. + x = x.view(-1, x.shape[-1]) + output = ops.gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + experts_per_rank = mpu.experts_per_rank(self.args) + + # Reshape to [world_size, num_experts_per_rank]. + world_size = mpu.get_expert_parallel_world_size(self.args) + repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) + + # TODO(tgale): It might be faster to do this on the GPU and + # then communicate the results back to the host. + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. 
+ send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so devices that share experts + # get all of the tokens assigned to them. + # + # TODO(tgale): Fuse this into the prior, local permutation. + x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, + recv_counts, + send_counts, + self.args.expert_parallel_group, + async_op=True, + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + replicate_bins = ops.inclusive_cumsum( + parallel_tokens_per_expert.flatten(), + 0, + ) + replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * mpu.hidden_sharding_degree(self.args), + dtype=torch.int32, + device=indices.device, + ), + mpu.experts_per_rank(self.args), + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, + tokens_received, + ).flatten() + + # TODO(tgale): The sort_end_bit here can be reduced. + parallel_bin_ids, parallel_indices = ops.sort( + parallel_top_expert, + self.sort_end_bit, + ) + + # Calculate the bins boundaries from the token counts. + parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, + dtype=torch.int, + ) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + tokens, _ = x.size() + expert_capacity = self.expert_capacity(tokens) + if expert_capacity == 0: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + if self.args.mlp_impl == 'grouped': + # GroupedMLP requires counts on CPU. We can use the tensor already + # moved to CPU for the prior all_to_all, which avoids an extra + # device synchronization. + parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( + dim=0, + dtype=torch.int, + ) + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all( + parallel_x, + send_counts, + recv_counts, + self.args.expert_parallel_group, + ) + + # Reduce along the hidden sharding to get the final outputs. + # + # TODO(tgale): Fuse this into the following local permutation. 
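+        #
+        # With hidden_sharding_degree = hsd, the tensor received above holds
+        # hsd partial results per token; the view/sum below folds
+        # [hsd * num_tokens, hidden_size] into [hsd, num_tokens, hidden_size]
+        # and reduces over the first dimension.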
+ shape = ( + mpu.hidden_sharding_degree(self.args), + -1, + self.args.hidden_size, + ) + x = ops.sum(x.view(shape), dim=0) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + return x, tokens_per_expert.flatten() + + def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + in_shape = x.size() + + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) + if self.training and self.args.moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, scores)) + x = x.view(in_shape) + if self.bias is not None: + if self.args.return_bias: + return x, self.bias + return x + self.bias + return x + + +class MoE(torch.nn.Module): + + def __init__(self, args: Arguments): + super(MoE, self).__init__() + + # Token router. + self.router = router.LearnedRouter(args) + + # Expert computation helper. + self.experts = self._init_experts_mlp(args) + + self.shared_expert = None + if args.shared_expert: + # SharedExpert computation helper. + self.shared_expert = sharedexpert_registry.get(args) + + def _init_experts_mlp(self, args: Arguments): + return ParallelMLP(args) + + def forward(self, x: torch.Tensor): + # NOTE: If we're going to cast the activations to lower precision + # do it before we permute the tokens to save bandwidth. + x = common.cast_if_autocast_enabled(x) + + # Compute the expert scores and assignments. + scores, expert_weights, top_experts = self.router(x) + + # Compute the experts. + out = self.experts(x, scores, expert_weights, top_experts) + if self.shared_expert is not None: + shared_expert_out = self.shared_expert(x) + out = self.shared_expert.add_experts_sharedexpert( + shared_expert_out, + out, + ) + return out diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mpu.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mpu.py new file mode 100644 index 0000000000000000000000000000000000000000..b23213902f4567d7fdb0158cbcf5406c2b2aa601 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mpu.py @@ -0,0 +1,93 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.distributed as dist + +from megablocks.layers.arguments import Arguments + + +class MoeParam(torch.Tensor): + + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool + + +def is_moe_param(tensor: torch.Tensor) -> bool: + return hasattr(tensor, 'expert_model_parallel') + + +def get_expert_parallel_world_size(args: Arguments) -> int: + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) + + +def get_expert_parallel_rank(args: Arguments) -> int: + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) + + +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, 'expert_model_parallel') + setattr(tensor, 'expert_model_parallel', is_parallel) + + +def param_is_expert_model_parallel(param: MoeParam) -> bool: + return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) + + +def copy_expert_model_parallel_attributes( + destination_tensor: torch.Tensor, + source_tensor: torch.Tensor, +): + if hasattr(source_tensor, 'expert_model_parallel'): + setattr( + destination_tensor, + 'expert_model_parallel', + getattr(source_tensor, 
'expert_model_parallel'), + ) + + +def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + for i in range(world_size): + dist.barrier(group) + if i == rank: + print(f'rank = {rank}', *x) + + +# Helpers for expert/tensor sharding. +def expert_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = min(world_size, args.moe_num_experts) + + if (args.moe_num_experts % esd) != 0: + raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) + return esd + + +def hidden_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = expert_sharding_degree(args) + hsd = world_size // esd + + if (args.ffn_hidden_size % hsd) != 0: + raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", + ) + return hsd + + +def experts_per_rank(args: Arguments) -> int: + return args.moe_num_experts // expert_sharding_degree(args) + + +def features_per_rank(args: Arguments) -> int: + return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/router.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/router.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9dcd9e433322482a80d5b95afccee5c12368f8 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/router.py @@ -0,0 +1,114 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch + +from megablocks.layers import common +from megablocks.layers.arguments import Arguments + +_ROUTER_LOGITS = [] + + +def _save_router_logits(logits: torch.Tensor, args: Arguments): + if args.moe_zloss_weight == 0: + return + global _ROUTER_LOGITS + _ROUTER_LOGITS.append(logits) + + +def clear_router_zloss(): + global _ROUTER_LOGITS + _ROUTER_LOGITS.clear() + + +def batched_router_zloss(args: Arguments): + global _ROUTER_LOGITS + + if args.moe_zloss_weight == 0: + import warnings + warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') + return 0 + + logits_per_router = _ROUTER_LOGITS + + if args.moe_zloss_in_fp32: + logits_per_router = [logits.float() for logits in logits_per_router] + + unscaled_zloss_per_router = torch.stack([ + torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router + ]) + + return args.moe_zloss_weight * unscaled_zloss_per_router + + +# NOTE: To enable end-to-end benchmarking without convergence we +# support a flag to force the router to assign tokens uniformly +# across the experts. We do this with a custom autograd operation +# so that PyTorch still executes the full set of router operation. +class _UniformExpertAssignment(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, num_experts: int): + out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) + out = torch.remainder(out, num_experts) + return out.view(x.shape) + + +_uniform_expert_assignment = _UniformExpertAssignment.apply + + +class LearnedRouter(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + + # Learned router parameters. + # + # NOTE: This weight matrix is not parallelized with expert model + # parallelism. 
Each device needs the entire router weight matrix + # so that it can route its batch of data correctly. + self.layer = torch.nn.Linear( + args.hidden_size, + args.moe_num_experts, + bias=False, + dtype=common.dtype(args), + device=args.device, + ) + args.init_method(self.layer.weight) + + def jitter(self, x: torch.Tensor): + low: float = 1.0 - self.args.moe_jitter_eps + high: float = 1.0 + self.args.moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return low + noise * (high - low) + + def _top_k(self, scores: torch.Tensor): + if self.args.moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, self.args.moe_top_k, dim=-1) + + def forward(self, x: torch.Tensor): + if self.training and self.args.moe_jitter_eps is not None: + x = x * self.jitter(x) + + logits = self.layer(x.view(-1, x.shape[-1])) + _save_router_logits(logits, self.args) + scores = logits.softmax(dim=-1) + expert_weights, expert_indices = self._top_k(scores) + if self.args.moe_normalize_expert_weights: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + + expert_indices = ( + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) if self.args.uniform_expert_assignment else expert_indices + ) + return scores, expert_weights, expert_indices diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/sharedexpert_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..0f62db39a2dda5642f6baa9d9b949f7c31cf6d35 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -0,0 +1,30 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments + +_REGISTRY = { + 'mlp': mlp.SharedMLP, + 'glu': glu.SharedGLU, +} + + +def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: + """Returns an SharedMLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + SharedMLP instance. + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated SharedMLP constructed using the input args. 
+ """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + return _REGISTRY[args.mlp_type](args) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from megablocks.ops.binned_gather import binned_gather +from megablocks.ops.binned_scatter import binned_scatter +from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum +from megablocks.ops.gather import gather +from megablocks.ops.histogram import histogram +from megablocks.ops.padded_gather import padded_gather +from megablocks.ops.padded_scatter import padded_scatter +from megablocks.ops.repeat import repeat +from megablocks.ops.replicate import replicate +from megablocks.ops.round_up import round_up +from megablocks.ops.scatter import scatter +from megablocks.ops.sort import sort +from megablocks.ops.sum import sum +from megablocks.ops.topology import topology + +__all__ = [ + 'binned_gather', + 'binned_scatter', + 'exclusive_cumsum', + 'inclusive_cumsum', + 'gather', + 'histogram', + 'padded_gather', + 'padded_scatter', + 'repeat', + 'replicate', + 'round_up', + 'scatter', + 'sort', + 'sum', + 'topology', +] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..47b95301dbb35b0c3154d1ff2878d20792ce5cb2 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -0,0 +1,60 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + +from megablocks import benchmark_util +from megablocks.layers.all_to_all import all_to_all + +_ALL_TO_ALL_BENCHMARK = ( + (8, 1024), + (16, 1024), + (32, 1024), + (64, 1024), + (128, 1024), + (256, 1024), + (512, 1024), + (1024, 1024), + (2 * 1024, 1024), + (4 * 1024, 1024), + (8 * 1024, 1024), + (16 * 1024, 1024), + (32 * 1024, 1024), + (64 * 1024, 1024), + (128 * 1024, 1024), + (256 * 1024, 1024), + (512 * 1024, 1024), + (1024 * 1024, 1024), +) + + +def benchmark_all_to_all(group, sl, hs): + world_size = dist.get_world_size(group) + assert (sl % world_size) == 0 + send_recv_sizes = [sl // world_size] * world_size + + x = torch.randn((sl, hs)).cuda().half() + + details = { + 'world_size': world_size, + 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. 
+ } + + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) + + time, std = benchmark_util.benchmark_function(benchmark) + + if dist.get_rank(group) == 0: + benchmark_util.log_benchmark('All-To-All', details, time, std) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _ALL_TO_ALL_BENCHMARK: + benchmark_all_to_all(group, *args) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..89cce1b627137d7f987baca62fd6d82c5c04659a --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py @@ -0,0 +1,37 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_gather kernel. +class BinnedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): + ctx.save_for_backward(indices, bins) + ctx.top_k = top_k + return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + indices, bins = ctx.saved_tensors + out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) + return out, None, None, None, None + + +binned_gather = BinnedGatherOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ce0d6f5a890a58b89e6a501216a9822341323f --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py @@ -0,0 +1,59 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for binned_scatter kernel. +class BinnedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + assert len(x.size()) == 3 + ctx.bin_size = x.size(1) + ctx.top_k = top_k + + # TODO(tgale): Don't save 'x' for backwards if we don't need to + # calculate the gradient w.r.t. 'weights'. 
+ ctx.save_for_backward(x, indices, weights, bins) + return kernels.binned_scatter(x, indices, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + x, indices, weights, bins = ctx.saved_tensors + out = kernels.binned_gather( + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[2]: + wgrad = kernels.binned_scatter_wgrad( + x, + grad, + indices, + bins, + ctx.top_k, + ) + return out, None, wgrad, None, None + + +binned_scatter = BinnedScatterOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/cumsum.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/cumsum.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7c5137c900d66a959a76fa681436362a4df906 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/cumsum.py @@ -0,0 +1,52 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + # import megablocks_ops as ops # type: ignore + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrappers for cumsum kernels. +# NOTE: Does not support gradients. +class ExclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int): + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.exclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.exclusive_cumsum(x, dim, out) + return out + + +exclusive_cumsum = ExclusiveCumsumOp.apply + + +class InclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.inclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.inclusive_cumsum(x, dim, out) + return out + + +inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py new file mode 100644 index 0000000000000000000000000000000000000000..41b09a1233e8a996ff1062579cd0810d095ad1e6 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for gather kernel. 
+class GatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins) + ctx.top_k = top_k + return kernels.gather(x, indices, bin_ids, None, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins = ctx.saved_tensors + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) + return out, None, None, None, None, None + + +gather = GatherOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram.py new file mode 100644 index 0000000000000000000000000000000000000000..77e06a29163aa335dc61e8c50932d5f83322ae1d --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram.py @@ -0,0 +1,27 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for histogram kernel. +# NOTE: Does not support gradients. +class HistogramOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, max_val: float): + return ops.histogram(x, max_val) + + +histogram = HistogramOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..9de8e65281c829ba63b5c60853dfa5bb0a333988 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -0,0 +1,78 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_HISTOGRAM_TESTS = ( + (16384, torch.int32, 2), + (16384, torch.int32, 4), + (16384, torch.int32, 8), + (16384, torch.int32, 16), + (16384, torch.int32, 32), + (16384, torch.int32, 64), + (16384, torch.int32, 128), + (16384, torch.int32, 256), +) + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. 
+ fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class HistogramBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testHistogram(self, n, dtype, max_val): + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testTorchHistogram(self, n, dtype, max_val): + x = torch.randint(0, 128, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..bfa7b7c2f697f9454c5381e31ee78040be5b229e --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -0,0 +1,403 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import stk +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + + +# Calling tensor.t() calls tensor.transpose(0, 1) which calls +# torch.as_strided(...). Circumvent this chain to avoid an overhead +# this adds. +def transpose_view(x): + return torch.as_strided( + x, + (x.shape[1], x.shape[0]), + (x.stride()[1], x.stride()[0]), + ) + + +_MATMUL_TESTS = ( + (64 * 1024, 512, 2048, 64), + (32 * 1024, 768, 3072, 64), + (8 * 1024, 1024, 4096, 64), + (4 * 2048, 4096, 4 * 4096, 4), +) + + +def log_benchmark(name, arguments, time, std, flops): + benchmark_util.log_benchmark(name, arguments, time, std) + print('flops = {:.2f}B'.format(flops / 1e9)) + print('throughput = {:.2f}T'.format(flops / 1e9 / time)) + print('=' * 60) + + +class MatmulBenchmark(parameterized.TestCase): + + def build_sparse_matrix(self, x, padded_bins, fhs, ne): + blocking = 128 + padded_tokens, _ = x.size() + assert padded_tokens % blocking == 0 + assert fhs % blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // blocking + blocks_per_row = fhs // blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. 
+ column_indices = ops.topology( + padded_bins, + blocking, + block_rows, + blocks_per_row, + ) + data = torch.empty( + column_indices.numel(), + blocking, + blocking, + dtype=torch.float16, + device=x.device, + ) + shape = (padded_tokens, fhs * ne) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix(shape, data, row_indices, column_indices, offsets) + + def build_input_matrix(self, sl, hs, ne): + x = torch.randn((sl, hs)).cuda().half() + + # Assign tokens to experts uniformly. + top_expert = torch.arange(0, sl).cuda().int() % ne + + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) + return out, padded_bins + + def build_weight_matrix(self, ne, hs, fhs): + return torch.randn((hs, ne * fhs)).cuda().half() + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(x, w, topo) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd::SDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(topo, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradX::DSD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + topo = topo.t() + + def benchmark(): + return stk.ops.dsd(topo, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DSD::NN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): 
+ x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(out, w, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::SDD::NT', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + x = x.t() + + def benchmark(): + return stk.ops.dsd(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + + w = w.transpose(1, 2).contiguous() + w = w.transpose(1, 2) + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd:DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + w = w.transpose(1, 2).contiguous() + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradX:DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + out = out.transpose(1, 2) + + def benchmark(): + return torch.bmm(out, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradW:DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) 
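+
+    # Test-name convention used in the benchmarks above and below:
+    # Linear0/Linear1 are the two FFN layers; SDD/DSD/DDD name the matmul
+    # kind (first letter = output, then the two inputs; S = sparse,
+    # D = dense, DDD = plain batched GEMM); the trailing NT/NN/TN flags
+    # record GEMM-style transposition of the lhs/rhs operands.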
+ + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + w = torch.transpose(w, 1, 2) + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + x = torch.transpose(x, 1, 2) + + def benchmark(): + return torch.bmm(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..f272a7768dc6468e81bd2cd25294ca16a6826c08 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py @@ -0,0 +1,55 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_gather kernel. +class PaddedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins, padded_bins) + ctx.top_k = top_k + return kernels.padded_gather( + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins, padded_bins = ctx.saved_tensors + out = kernels.padded_scatter( + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) + return out, None, None, None, None, None + + +padded_gather = PaddedGatherOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff81dd9456a04258346266e9e85871bda56c65b --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py @@ -0,0 +1,98 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from stk.backend.autocast import custom_bwd, custom_fwd + +from megablocks.backend import kernels + + +# Autograd wrapper for padded_scatter kernel. 
+class PaddedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward( + indices, + bin_ids, + weights, + bins, + padded_bins, + *maybe_x, + ) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.padded_gather( + grad, + indices, + bin_ids, + weights, + bins, + padded_bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.padded_scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + padded_bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None, None + + +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): + return PaddedScatterOp.apply( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..81dde4e4b4c0f550c6027390e8de285fa4d842f7 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -0,0 +1,66 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PADDED_SCATTER_BENCHMARK = ( + # dMoE-Medium, 8-way EMP. + (1024 * 16, 1024, 8, 4), + # dMoE-Medium, post-all-to-all. + (1024 * 16 * 4, 1024, 8, 1), +) + + +class PaddedScatterTest(parameterized.TestCase): + + @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) + def testPaddedScatter(self, sl, hs, ne, top_k): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + # Sample weights for the scatter reduce. + weights = torch.rand((sl * top_k,)).cuda().half() + + # Gather the data to prepare for backwards. 
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + def benchmark(): + return ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + time, std = benchmark_util.benchmark_function(benchmark) + benchmark_util.log_benchmark( + 'Padded Scatter', + { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + 'top_k': top_k, + }, + time, + std, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/permute_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..837f07e2858d5c65c14e3f262617e7d2985ce901 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -0,0 +1,149 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops + +_PERMUTE_TESTS = ( + (16384, 768, 2), + (16384, 768, 4), + (16384, 768, 8), + (16384, 768, 16), + (16384, 768, 32), + (16384, 768, 64), + (16384, 768, 128), + (16384 * 8, 768, 2), + (16384 * 8, 768, 4), + (16384 * 8, 768, 8), + (16384 * 8, 768, 16), + (16384 * 8, 768, 32), + (16384 * 8, 768, 64), + (16384 * 8, 768, 128), +) + + +class PermuteBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedGather(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.binned_gather(x, indices, bins, ec) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedScatter(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.binned_gather(x, indices, bins, ec) + + def benchmark(): + return ops.binned_scatter(x, indices, bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedGather(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. 
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedScatter(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testCopy(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + # ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + y = x.clone() + + def benchmark(): + return y.copy_(x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/repeat.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/repeat.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def repeat(x: torch.Tensor, tiling: torch.Size): + if all((t == 1 for t in tiling)): + return x + return x.repeat(*tiling) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/replicate.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..e570e22093339d1b71dc830edf656897c67f40e7 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/replicate.py @@ -0,0 +1,36 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. 
+try:
+    from megablocks._ops import ops  # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+        ctx.save_for_backward(bins)
+        out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+        ops.replicate_forward(x, bins, out)
+        return out
+
+    @staticmethod
+    def backward(ctx: Any, grad: torch.Tensor):
+        bins, = ctx.saved_tensors
+        out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+        ops.replicate_backward(grad, bins, out)
+        return out, None, None
+
+
+replicate = ReplicateOp.apply
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/round_up.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/round_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/round_up.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+    assert isinstance(value, int)
+    assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue,
+    # do this in a custom kernel. We only expect
+    # to use this on arrays of less than 1k elements.
+    return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5aaafc402a651a47595c25dbf71e382903e5022
--- /dev/null
+++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from stk.backend.autocast import custom_bwd, custom_fwd
+
+from megablocks.backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins = saved_tensors[:4] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.gather( + grad, + indices, + bin_ids, + weights, + bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None + + +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +) -> Optional[torch.Tensor]: + return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort.py new file mode 100644 index 0000000000000000000000000000000000000000..a64a4553323d8de71a1d5f1821bac50249c68197 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional, Tuple + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + +_BITS_FOR_DTYPE = { + torch.int16: 16, + torch.int32: 32, + torch.int64: 64, +} + + +# Autograd wrapper for sort kernel. +# NOTE: Does not support gradients. 
+class SortOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + if end_bit is None: + end_bit = _BITS_FOR_DTYPE[x.dtype] + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + ops.sort(x, end_bit, x_out, iota_out) + return (x_out, iota_out) + + +sort = SortOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -0,0 +1,85 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from megablocks import ops + +_SORT_TESTS = ( + (16384, torch.int32, None), + (16384, torch.int32, 2), + (16384, torch.int32, 128), +) + +_BASELINE_SORT_TESTS = ((16384,),) + + +def numpy_dtype(dtype): + types = { + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + } + return types[dtype] + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. + fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class SortBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_SORT_TESTS) + def testSort(self, n, dtype, max_val): + if max_val is None: + max_val = np.iinfo(numpy_dtype(dtype)).max + end_bit = int(np.ceil(np.log2(max_val))) + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_BASELINE_SORT_TESTS) + def testTorchSort(self, n): + x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) + arguments = { + 'n': n, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sum.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sum.py new file mode 100644 index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sum.py @@ -0,0 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import torch + + +def sum(x: torch.Tensor, dim: int = 0): + if x.shape[dim] == 1: + return x.squeeze(dim=dim) + return x.sum(dim=dim) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/topology.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/topology.py new file mode 100644 index 
0000000000000000000000000000000000000000..b4a693a82a36c052e51afb4a02077e7849a4abf1 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/topology.py @@ -0,0 +1,45 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from megablocks._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for topology kernel. +# NOTE: Does not support gradients. +class TopologyOp(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, + ): + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + ops.indices( + padded_bins, + block_size, + output_block_rows, + output_block_columns, + out, + ) + return out + + +topology = TopologyOp.apply
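For orientation, the benchmark files above all follow the same routing recipe: sort the token-to-expert assignments, histogram them into per-expert counts, round the counts up to the kernel block size, build inclusive-cumsum bin boundaries, and then move tokens through the padded layout with padded_gather/padded_scatter. The snippet below is a minimal sketch of that pipeline, assuming a CUDA device, a working build of the extension importable as `from megablocks import ops` (as in the benchmarks), and illustrative sizes chosen arbitrarily; it mirrors the setup code in padded_scatter_benchmark.py rather than documenting a stable public API.

# Sketch of the padded routing pipeline used by the benchmarks above.
# Sizes (sl, hs, ne, top_k) are hypothetical and only for illustration.
import torch
from megablocks import ops

sl, hs, ne, top_k = 1024, 256, 8, 2

# Token activations and random expert assignments, as in the benchmarks.
x = torch.randn((sl, hs)).cuda().half()
top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()

# Sort assignments and build regular and padded bin boundaries.
bin_ids, indices = ops.sort(top_expert)
tokens_per_expert = ops.histogram(top_expert, ne)
padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
bins = ops.inclusive_cumsum(tokens_per_expert, 0)

# Route tokens into the padded, expert-major layout and back,
# weighting each copy on the way out (the scatter-reduce step).
weights = torch.rand((sl * top_k,)).cuda().half()
x_padded = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
y = ops.padded_scatter(x_padded, indices, bin_ids, weights, bins, padded_bins, top_k)

In a dMoE layer the expert computation would run on x_padded between the gather and the scatter; here the two are composed directly, which is exactly the round trip the padded scatter benchmark times.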