kernel
drbh committed
Commit a585153 · 1 Parent(s): 359242d

feat: add build output

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50):
  1. .gitattributes +2 -1
  2. build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py +195 -0
  3. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so +3 -0
  4. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +9 -0
  5. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_version.py +6 -0
  6. build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py +2 -0
  7. build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py +543 -0
  8. build/torch26-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py +23 -0
  9. build/torch26-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py +35 -0
  10. build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py +26 -0
  11. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py +10 -0
  12. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/activation_fn.py +33 -0
  13. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/all_to_all.py +54 -0
  14. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py +100 -0
  15. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py +26 -0
  16. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py +42 -0
  17. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py +327 -0
  18. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/gelu.py +43 -0
  19. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py +223 -0
  20. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py +102 -0
  21. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py +574 -0
  22. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py +475 -0
  23. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py +93 -0
  24. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py +114 -0
  25. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py +30 -0
  26. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py +35 -0
  27. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +60 -0
  28. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py +37 -0
  29. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +59 -0
  30. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py +52 -0
  31. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py +38 -0
  32. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py +27 -0
  33. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py +78 -0
  34. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py +403 -0
  35. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py +55 -0
  36. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +98 -0
  37. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +66 -0
  38. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py +149 -0
  39. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py +10 -0
  40. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py +36 -0
  41. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py +14 -0
  42. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py +72 -0
  43. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py +38 -0
  44. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py +85 -0
  45. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py +9 -0
  46. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py +45 -0
  47. build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py +195 -0
  48. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so +3 -0
  49. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +9 -0
  50. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_version.py +6 -0
.gitattributes CHANGED
@@ -32,4 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.so filter=lfs diff=lfs merge=lfs -text
build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py ADDED
@@ -0,0 +1,195 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+
6
+ from ._ops import ops
7
+
8
+ from megablocks.layers.arguments import Arguments
9
+ from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE
10
+ from megablocks.layers.glu import SparseGLU
11
+ from megablocks.layers.mlp import MLP, SparseMLP
12
+ from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss
13
+
14
+ # This section contains the direct kernel exports (not included in the original code)
15
+ def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
16
+ """
17
+ Compute exclusive cumulative sum along the specified dimension.
18
+
19
+ Args:
20
+ x: Input tensor
21
+ dim: Dimension along which to compute cumsum
22
+ out: Output tensor (modified in-place)
23
+
24
+ Returns:
25
+ The output tensor
26
+ """
27
+ result = ops.exclusive_cumsum(x, dim)
28
+ out.copy_(result)
29
+ return out
30
+
31
+
32
+ def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
33
+ """
34
+ Compute inclusive cumulative sum along the specified dimension.
35
+
36
+ Args:
37
+ x: Input tensor
38
+ dim: Dimension along which to compute cumsum
39
+ out: Output tensor (modified in-place)
40
+
41
+ Returns:
42
+ The output tensor
43
+ """
44
+ result = ops.inclusive_cumsum(x, dim)
45
+ out.copy_(result)
46
+ return out
47
+
48
+
49
+ def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
50
+ """
51
+ Compute histogram of input tensor values.
52
+
53
+ Args:
54
+ x: Input tensor
55
+ num_bins: Number of histogram bins
56
+
57
+ Returns:
58
+ Histogram tensor with counts for each bin
59
+ """
60
+ return ops.histogram(x, num_bins)
61
+
62
+
63
+ def indices(
64
+ padded_bins: torch.Tensor,
65
+ block_size: int,
66
+ output_block_rows: int,
67
+ output_block_columns: int,
68
+ ) -> torch.Tensor:
69
+ """
70
+ Construct indices from padded bins for sparse operations.
71
+
72
+ Args:
73
+ padded_bins: Tensor containing bin boundaries
74
+ block_size: Size of each block
75
+ output_block_rows: Number of rows in output blocks
76
+ output_block_columns: Number of columns in output blocks
77
+
78
+ Returns:
79
+ Tensor containing constructed indices
80
+ """
81
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
82
+
83
+
84
+ def replicate_forward(
85
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
86
+ ) -> torch.Tensor:
87
+ """
88
+ Forward pass of replicate operation - replicate values according to bin sizes.
89
+
90
+ Args:
91
+ x: Input tensor with values to replicate
92
+ bins: Tensor containing bin sizes
93
+ out: Output tensor (modified in-place)
94
+
95
+ Returns:
96
+ The output tensor
97
+ """
98
+ return ops.replicate_forward(x, bins, out)
99
+
100
+
101
+ def replicate_backward(
102
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
103
+ ) -> torch.Tensor:
104
+ """
105
+ Backward pass of replicate operation - reduce gradients back to bins.
106
+
107
+ Args:
108
+ grad: Gradient tensor to reduce
109
+ bins: Tensor containing bin sizes
110
+ out: Output tensor (modified in-place)
111
+
112
+ Returns:
113
+ The output tensor
114
+ """
115
+ return ops.replicate_backward(grad, bins, out)
116
+
117
+
118
+ def sort(
119
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
120
+ ) -> torch.Tensor:
121
+ """
122
+ Radix sort with index tracking.
123
+
124
+ Args:
125
+ x: Input tensor to sort
126
+ end_bit: Number of bits to consider in sorting
127
+ x_out: Output tensor for sorted values
128
+ iota_out: Output tensor for sorted indices
129
+
130
+ Returns:
131
+ The sorted values tensor
132
+ """
133
+ return ops.sort(x, end_bit, x_out, iota_out)
134
+
135
+
136
+ # Convenience functions for common use cases
137
+ def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
138
+ """
139
+ Compute cumulative sum with automatic output allocation.
140
+
141
+ Args:
142
+ x: Input tensor
143
+ dim: Dimension along which to compute cumsum (default: last dimension)
144
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
145
+
146
+ Returns:
147
+ New tensor containing the cumulative sum
148
+ """
149
+ out = torch.empty_like(x)
150
+ if exclusive:
151
+ return exclusive_cumsum(x, dim, out)
152
+ else:
153
+ return inclusive_cumsum(x, dim, out)
154
+
155
+
156
+ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
157
+ """
158
+ Sort tensor and return both sorted values and indices.
159
+
160
+ Args:
161
+ x: Input tensor to sort
162
+ end_bit: Number of bits to consider in sorting
163
+
164
+ Returns:
165
+ Tuple of (sorted_values, sorted_indices)
166
+ """
167
+ x_out = torch.empty_like(x)
168
+ iota_out = torch.empty_like(x)
169
+ sort(x, end_bit, x_out, iota_out)
170
+ return x_out, iota_out
171
+
172
+
173
+ # Export public API
174
+ __all__ = [
175
+ # Direct kernel exports
176
+ "exclusive_cumsum",
177
+ "inclusive_cumsum",
178
+ "histogram",
179
+ "indices",
180
+ "replicate_forward",
181
+ "replicate_backward",
182
+ "sort",
183
+ "cumsum",
184
+ "argsort",
185
+ # Original exports
186
+ "Arguments",
187
+ "ParallelDroplessMLP",
188
+ "dMoE",
189
+ "SparseGLU",
190
+ "MLP",
191
+ "SparseMLP",
192
+ "MoE",
193
+ "ParallelMLP",
194
+ "get_load_balancing_loss",
195
+ ]
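A quick usage sketch of the convenience wrappers exported above (hypothetical shapes; assumes a CUDA build of the extension, since the underlying ops are GPU kernels, and that the chosen integer dtype is supported by the compiled ops):

import torch
import megablocks

# 256 routing decisions over 8 experts.
x = torch.randint(0, 8, (256,), dtype=torch.int32, device='cuda')

counts = megablocks.histogram(x, num_bins=8)      # tokens routed to each of the 8 bins
bounds = megablocks.cumsum(counts, dim=0)         # inclusive bin boundaries
sorted_vals, sorted_idx = megablocks.argsort(x)   # radix-sorted values and their indices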
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:36a918d61308a0acbc880516a42fcc2c4dc393f25655326e52128465c9501709
+ size 10456376
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _megablocks_359242d
+ ops = torch.ops._megablocks_359242d
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_megablocks_359242d::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_version.py ADDED
@@ -0,0 +1,6 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ """The MegaBlocks Version."""
+
+ __version__ = '0.11.0.dev0'
build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py ADDED
@@ -0,0 +1,543 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+
9
+ def assert_is_tensor(x, ndim):
10
+ if x.ndim != ndim:
11
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
12
+
13
+
14
+ def assert_is_matrix(x):
15
+ assert_is_tensor(x, 2)
16
+
17
+
18
+ def assert_is_vector(x):
19
+ if x.ndim != 1:
20
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
21
+
22
+
23
+ def assert_equal(a, b):
24
+ if a != b:
25
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
26
+
27
+
28
+ # a: (tokens, hidden_size), real.
29
+ # indices: (tokens * top_k), integer.
30
+ # bin_ids: (tokens * top_k), integer.
31
+ # weights: (tokens * top_k), real.
32
+ # bins: (num_experts), integer.
33
+ # padded_bins: (num_experts), integer.
34
+ @triton.autotune(
35
+ configs=[
36
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
37
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
38
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
39
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
40
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
41
+ ],
42
+ key=['NUM_COLUMNS'],
43
+ )
44
+ @triton.jit
45
+ def _padded_copy(
46
+ a,
47
+ b,
48
+ indices,
49
+ bin_ids,
50
+ weights,
51
+ bins,
52
+ padded_bins,
53
+ NUM_COLUMNS: tl.constexpr,
54
+ TOP_K: tl.constexpr,
55
+ BLOCK_X: tl.constexpr,
56
+ A_TO_B: tl.constexpr,
57
+ SCALE: tl.constexpr,
58
+ ):
59
+ # Our index into array 'a'.
60
+ index_a = tl.load(indices + tl.program_id(0))
61
+
62
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
63
+ # number of rows since they could be padded.
64
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
65
+
66
+ # Now we know what bin we're assigned to, but we need to know how
67
+ # many threadblocks were assigned to earlier bins so we can offset
68
+ # in our bin properly.
69
+ offset_in_bin = tl.program_id(0)
70
+ if bin_idx > 0:
71
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
72
+
73
+ # Load the starting index of our bin in array 'b'.
74
+ index_b = offset_in_bin
75
+ if bin_idx > 0:
76
+ index_b += tl.load(padded_bins + bin_idx - 1)
77
+
78
+ # Offset the input and output pointers.
79
+ #
80
+ # If we're going from A to B, divide the input index to copy
81
+ # the same input repeatedly. If we're going from B to A we
82
+ # need to reduce the result. Using atomics is slow, so we
83
+ # do the reduce step in a second kernel.
84
+ offset = index_a // TOP_K if A_TO_B else index_a
85
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
86
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
87
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
88
+
89
+ # Load the scale, if requested.
90
+ scale = tl.load(weights + index_a) if SCALE else 1
91
+
92
+ # Swap the pointers depending on the direction.
93
+ iptr = a if A_TO_B else b
94
+ optr = b if A_TO_B else a
95
+
96
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
97
+ for _ in range(iterations):
98
+ mask = offsets < NUM_COLUMNS
99
+ x = tl.load(iptr + offsets, mask=mask)
100
+ x = x.to(tl.float32) * scale.to(tl.float32)
101
+
102
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
103
+
104
+ offsets += BLOCK_X
105
+
106
+
107
+ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
108
+ # Validate the input shapes.
109
+ assert_is_matrix(x)
110
+ assert_is_vector(indices)
111
+ assert_is_vector(bin_ids)
112
+ assert_is_vector(bins)
113
+ assert_is_vector(padded_bins)
114
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
115
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
116
+ assert_equal(bins.size(), padded_bins.size())
117
+
118
+ if weights is not None:
119
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
120
+
121
+ # NOTE: Because of the padding, the output size is dynamic.
122
+ # We load the final padded bin bound to get the output rows.
123
+ output_rows = padded_bins[-1].cpu().item()
124
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
125
+ _padded_copy[(indices.shape[0],)](
126
+ x,
127
+ out,
128
+ indices,
129
+ bin_ids,
130
+ weights,
131
+ bins,
132
+ padded_bins,
133
+ NUM_COLUMNS=x.shape[1],
134
+ A_TO_B=True,
135
+ TOP_K=top_k,
136
+ SCALE=weights is not None,
137
+ )
138
+ return out
139
+
140
+
141
+ def gather(x, indices, bin_ids, weights, bins, top_k):
142
+ # Validate the input shapes.
143
+ assert_is_matrix(x)
144
+ assert_is_vector(indices)
145
+ assert_is_vector(bin_ids)
146
+ assert_is_vector(bins)
147
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
148
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
149
+
150
+ if weights is not None:
151
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
152
+
153
+ # NOTE: There is no padding so the output rows equals the
154
+ # input rows multiplied by top_k.
155
+ output_rows = x.shape[0] * top_k
156
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
157
+ _padded_copy[(indices.shape[0],)](
158
+ x,
159
+ out,
160
+ indices,
161
+ bin_ids,
162
+ weights,
163
+ bins,
164
+ bins,
165
+ NUM_COLUMNS=x.shape[1],
166
+ A_TO_B=True,
167
+ TOP_K=top_k,
168
+ SCALE=weights is not None,
169
+ )
170
+ return out
171
+
172
+
173
+ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
174
+ # Validate the input shapes.
175
+ assert_is_matrix(x)
176
+ assert_is_vector(indices)
177
+ assert_is_vector(bin_ids)
178
+ assert_is_vector(bins)
179
+ assert_is_vector(padded_bins)
180
+ assert_equal(indices.shape[0], bin_ids.shape[0])
181
+ assert_equal(bins.size(), padded_bins.size())
182
+
183
+ if weights is not None:
184
+ assert_equal(indices.shape[0], weights.shape[0])
185
+
186
+ tokens = indices.shape[0] // top_k
187
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
188
+ _padded_copy[(indices.shape[0],)](
189
+ out,
190
+ x,
191
+ indices,
192
+ bin_ids,
193
+ weights,
194
+ bins,
195
+ padded_bins,
196
+ NUM_COLUMNS=x.shape[1],
197
+ A_TO_B=False,
198
+ TOP_K=top_k,
199
+ SCALE=weights is not None,
200
+ )
201
+
202
+ # Reduce along the top-k dimension, if needed.
203
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
204
+
205
+
206
+ def scatter(x, indices, bin_ids, weights, bins, top_k):
207
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
208
+
209
+
210
+ # x: (tokens, top_k, hidden_size), real
211
+ # grad: (tokens, hidden_size), real.
212
+ # wgrad: (tokens, top_k), real.
213
+ # indices: (tokens * top_k), integer.
214
+ # bin_ids: (tokens * top_k), integer.
215
+ # bins: (num_experts), integer.
216
+ # padded_bins: (num_experts), integer.
217
+ @triton.autotune(
218
+ configs=[
219
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
220
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
221
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
222
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
223
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
224
+ ],
225
+ key=['NUM_COLUMNS'],
226
+ )
227
+ @triton.jit
228
+ def _padded_copy_wgrad(
229
+ x,
230
+ grad,
231
+ wgrad,
232
+ indices,
233
+ bin_ids,
234
+ bins,
235
+ padded_bins,
236
+ NUM_COLUMNS: tl.constexpr,
237
+ TOP_K: tl.constexpr,
238
+ BLOCK_X: tl.constexpr,
239
+ ):
240
+ # Our index into 'tokens * top_k'.
241
+ index_out = tl.load(indices + tl.program_id(0))
242
+
243
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
244
+ # number of rows since they could be padded.
245
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
246
+
247
+ # Now we know what bin we're assigned to, but we need to know how
248
+ # many threadblocks were assigned to earlier bins so we can offset
249
+ # in our bin properly.
250
+ offset_in_bin = tl.program_id(0)
251
+ if bin_idx > 0:
252
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
253
+
254
+ # Load the starting index of our bin in array 'x'.
255
+ index_x = offset_in_bin
256
+ if bin_idx > 0:
257
+ index_x += tl.load(padded_bins + bin_idx - 1)
258
+
259
+ # Offset the input and output pointers.
260
+ wgrad += index_out
261
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
262
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
263
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
264
+
265
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
266
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
267
+ for _ in range(iterations):
268
+ mask = offsets < NUM_COLUMNS
269
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
270
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
271
+ acc += data * scale
272
+ offsets += BLOCK_X
273
+
274
+ # Reduce to get the final result and store.
275
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
276
+ tl.store(wgrad, out)
277
+
278
+
279
+ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
280
+ # Validate the input shapes.
281
+ assert_is_matrix(x)
282
+ assert_is_matrix(grad)
283
+ assert_is_vector(indices)
284
+ assert_is_vector(bin_ids)
285
+ assert_is_vector(bins)
286
+ assert_is_vector(padded_bins)
287
+ assert_equal(indices.shape[0], bin_ids.shape[0])
288
+ assert_equal(bins.size(), padded_bins.size())
289
+
290
+ tokens = indices.shape[0] // top_k
291
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
292
+ _padded_copy_wgrad[(indices.shape[0],)](
293
+ x,
294
+ grad,
295
+ out,
296
+ indices,
297
+ bin_ids,
298
+ bins,
299
+ padded_bins,
300
+ NUM_COLUMNS=x.shape[1],
301
+ TOP_K=top_k,
302
+ )
303
+ return out
304
+
305
+
306
+ def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
307
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
308
+
309
+
310
+ # a: (tokens, hidden_size), real.
311
+ # b: (num_experts, expert_capacity, num_columns), real.
312
+ # indices: (tokens * top_k), integer.
313
+ # weights: (tokens * top_k), real.
314
+ # bins: (num_experts), integer.
315
+ @triton.autotune(
316
+ configs=[
317
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
318
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
319
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
320
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
321
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
322
+ ],
323
+ key=['NUM_COLUMNS'],
324
+ )
325
+ @triton.jit
326
+ def _binned_copy(
327
+ a,
328
+ b,
329
+ num_experts,
330
+ expert_capacity,
331
+ indices,
332
+ weights,
333
+ bins,
334
+ NUM_COLUMNS: tl.constexpr,
335
+ TOP_K: tl.constexpr,
336
+ BLOCK_X: tl.constexpr,
337
+ A_TO_B: tl.constexpr,
338
+ SCALE: tl.constexpr,
339
+ ):
340
+ # Load our indices into the output.
341
+ expert_idx = tl.program_id(0)
342
+ entry_idx = tl.program_id(1)
343
+
344
+ # Calculate our offset into the output.
345
+ index_b = expert_idx * expert_capacity + entry_idx
346
+
347
+ # Load the index bounds for our bin and calculate
348
+ # the number of tokens assigned to our expert.
349
+ start = 0
350
+ if expert_idx > 0:
351
+ start = tl.load(bins + expert_idx - 1)
352
+ end = tl.load(bins + expert_idx)
353
+ num_tokens = end - start
354
+
355
+ # Calculate our offset into the input. If we don't
356
+ # have an input exit early.
357
+ if entry_idx >= num_tokens:
358
+ return
359
+ index_a = tl.load(indices + start + entry_idx)
360
+
361
+ # Offset the input and output pointers.
362
+ #
363
+ # If we're going from A to B, divide the input index to copy
364
+ # the same input repeatedly. If we're going from B to A we
365
+ # need to reduce the result. Using atomics is slow, so we
366
+ # do the reduce step in a second kernel.
367
+ offset = index_a // TOP_K if A_TO_B else index_a
368
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
369
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
370
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
371
+
372
+ # Load the scale, if requested.
373
+ scale = tl.load(weights + index_a) if SCALE else 1
374
+
375
+ # Swap the pointers depending on the direction.
376
+ #
377
+ # NOTE: We need to zero the output in both directions.
378
+ iptr = a if A_TO_B else b
379
+ optr = b if A_TO_B else a
380
+
381
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
382
+ for _ in range(iterations):
383
+ mask = offsets < NUM_COLUMNS
384
+ x = tl.load(iptr + offsets, mask=mask)
385
+ x = x.to(tl.float32) * scale.to(tl.float32)
386
+
387
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
388
+
389
+ offsets += BLOCK_X
390
+
391
+
392
+ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
393
+ # Validate the input shapes.
394
+ assert_is_matrix(x)
395
+ assert_is_vector(indices)
396
+ assert_is_vector(bins)
397
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
398
+
399
+ if weights is not None:
400
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
401
+
402
+ num_experts = bins.shape[0]
403
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
404
+
405
+ _binned_copy[(num_experts, expert_capacity)](
406
+ x,
407
+ out,
408
+ num_experts,
409
+ expert_capacity,
410
+ indices,
411
+ weights,
412
+ bins,
413
+ NUM_COLUMNS=x.shape[1],
414
+ A_TO_B=True,
415
+ TOP_K=top_k,
416
+ SCALE=weights is not None,
417
+ )
418
+ return out
419
+
420
+
421
+ def binned_scatter(x, indices, weights, bins, top_k):
422
+ # Validate the input shapes.
423
+ assert_is_tensor(x, 3)
424
+ assert_is_vector(indices)
425
+ assert_is_vector(bins)
426
+ assert_equal(bins.shape[0], x.shape[0])
427
+
428
+ if weights is not None:
429
+ assert_equal(indices.shape[0], weights.shape[0])
430
+
431
+ num_experts, expert_capacity, hidden_size = x.shape
432
+ tokens = indices.shape[0] // top_k
433
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
434
+ _binned_copy[(num_experts, expert_capacity)](
435
+ out,
436
+ x,
437
+ num_experts,
438
+ expert_capacity,
439
+ indices,
440
+ weights,
441
+ bins,
442
+ NUM_COLUMNS=hidden_size,
443
+ A_TO_B=False,
444
+ TOP_K=top_k,
445
+ SCALE=weights is not None,
446
+ )
447
+
448
+ # Reduce along the top-k dimension, if needed.
449
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
450
+
451
+
452
+ # a: (tokens, hidden_size), real.
453
+ # b: (num_experts, expert_capacity, num_columns), real.
454
+ # indices: (tokens * top_k), integer.
455
+ # weights: (tokens * top_k), real.
456
+ # bins: (num_experts), integer.
457
+ @triton.autotune(
458
+ configs=[
459
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
460
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
461
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
462
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
463
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
464
+ ],
465
+ key=['NUM_COLUMNS'],
466
+ )
467
+ @triton.jit
468
+ def _binned_copy_wgrad(
469
+ x,
470
+ grad,
471
+ wgrad,
472
+ num_experts,
473
+ expert_capacity,
474
+ indices,
475
+ bins,
476
+ NUM_COLUMNS: tl.constexpr,
477
+ TOP_K: tl.constexpr,
478
+ BLOCK_X: tl.constexpr,
479
+ ):
480
+ # Load our indices into the output.
481
+ expert_idx = tl.program_id(0)
482
+ entry_idx = tl.program_id(1)
483
+
484
+ # Calculate our offset into the output.
485
+ index_x = expert_idx * expert_capacity + entry_idx
486
+
487
+ # Load the index bounds for our bin and calculate
488
+ # the number of tokens assigned to our expert.
489
+ start = 0
490
+ if expert_idx > 0:
491
+ start = tl.load(bins + expert_idx - 1)
492
+ end = tl.load(bins + expert_idx)
493
+ num_tokens = end - start
494
+
495
+ # Calculate our offset into the input. If we don't
496
+ # have an input exit early.
497
+ if entry_idx >= num_tokens:
498
+ return
499
+ index_out = tl.load(indices + start + entry_idx)
500
+
501
+ # Offset the input and output pointers.
502
+ wgrad += index_out
503
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
504
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
505
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
506
+
507
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
508
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
509
+ for _ in range(iterations):
510
+ mask = offsets < NUM_COLUMNS
511
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
512
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
513
+ acc += data * scale
514
+ offsets += BLOCK_X
515
+
516
+ # Reduce to get the final result and store.
517
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
518
+ tl.store(wgrad, out)
519
+
520
+
521
+ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
522
+ # Validate the input shapes.
523
+ assert_is_tensor(x, 3)
524
+ assert_is_matrix(grad)
525
+ assert_is_vector(indices)
526
+ assert_is_vector(bins)
527
+ assert_equal(bins.shape[0], x.shape[0])
528
+
529
+ num_experts, expert_capacity, hidden_size = x.shape
530
+ tokens = indices.shape[0] // top_k
531
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
532
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
533
+ x,
534
+ grad,
535
+ out,
536
+ num_experts,
537
+ expert_capacity,
538
+ indices,
539
+ bins,
540
+ NUM_COLUMNS=hidden_size,
541
+ TOP_K=top_k,
542
+ )
543
+ return out
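As a rough illustration of how these Triton kernels fit together, the sketch below routes a small batch through padded_gather and straight back through padded_scatter (hypothetical sizes; assumes a CUDA device with Triton installed, and uses torch.sort/torch.bincount as stand-ins for the library's own sort and histogram ops):

import torch
from megablocks.backend import kernels

top_k, num_experts, blocking = 1, 4, 4
x = torch.randn(8, 16, device='cuda')                        # (tokens, hidden)
top_experts = torch.randint(0, num_experts, (8,), device='cuda', dtype=torch.int32)

bin_ids, indices = torch.sort(top_experts)                    # expert id per slot, permutation
tokens_per_expert = torch.bincount(top_experts, minlength=num_experts)
padded = ((tokens_per_expert + blocking - 1) // blocking) * blocking
bins = tokens_per_expert.cumsum(0).int()                      # inclusive bin bounds
padded_bins = padded.cumsum(0).int()                          # block-aligned bin bounds

routed = kernels.padded_gather(x, indices.int(), bin_ids.int(), None, bins, padded_bins, top_k)
restored = kernels.padded_scatter(routed, indices.int(), bin_ids.int(), None, bins, padded_bins, top_k)
# With no expert computation in between and no weights, restored should match x.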
build/torch26-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py ADDED
@@ -0,0 +1,23 @@
+ from megablocks_moe.megablocks import (
+     MoE,
+     dMoE,
+     get_load_balancing_loss,
+     ParallelMLP,
+     ParallelDroplessMLP,
+     SparseMLP,
+     MLP,
+     SparseGLU,
+     Arguments,
+ )
+
+ __all__ = [
+     "MoE",
+     "dMoE",
+     "get_load_balancing_loss",
+     "ParallelMLP",
+     "ParallelDroplessMLP",
+     "SparseMLP",
+     "MLP",
+     "SparseGLU",
+     "Arguments",
+ ]
build/torch26-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py ADDED
@@ -0,0 +1,35 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ import numpy as np
+ import torch
+
+
+ def log_benchmark(name, arguments, time, std):
+     print('=' * 60)
+     print(f'{name} Benchmark')
+     print('Benchmark Parameters:')
+     for (key, value) in arguments.items():
+         print(f'{key} = {value}')
+     print('Results:')
+     print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+     print('=' * 60)
+
+
+ def benchmark_function(fn, iterations=100, warmup=10):
+     # Warmup iterations.
+     for _ in range(warmup):
+         fn()
+
+     times = []
+     for i in range(iterations):
+         start = torch.cuda.Event(enable_timing=True)
+         end = torch.cuda.Event(enable_timing=True)
+
+         start.record()
+         fn()
+         end.record()
+
+         torch.cuda.synchronize()
+         times.append(start.elapsed_time(end))
+     return np.mean(times), np.std(times)
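For example, the two helpers above can be combined roughly like this (assumes a CUDA device; the matmul is just a stand-in workload):

import torch
from megablocks import benchmark_util

a = torch.randn(4096, 4096, device='cuda')
b = torch.randn(4096, 4096, device='cuda')

mean_ms, std_ms = benchmark_util.benchmark_function(lambda: a @ b)
benchmark_util.log_benchmark('MatMul', {'m': 4096, 'n': 4096, 'k': 4096}, mean_ms, std_ms)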
build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py ADDED
@@ -0,0 +1,26 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+ import warnings
+
+ _grouped_gemm_is_available: bool = False
+ try:
+     import grouped_gemm
+     _grouped_gemm_is_available = True
+ except ImportError as error:
+     warnings.warn('Grouped GEMM not available.')
+
+
+ def grouped_gemm_is_available():
+     return _grouped_gemm_is_available
+
+
+ def assert_grouped_gemm_is_available():
+     msg = (
+         'Grouped GEMM not available. Please run '
+         '`pip install git+https://github.com/tgale96/grouped_gemm@main`.',
+     )
+     assert _grouped_gemm_is_available, msg
+
+
+ backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+ ops = grouped_gemm.ops if grouped_gemm_is_available() else None
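Callers typically guard on availability before taking the grouped path, along these lines:

from megablocks import grouped_gemm_util as gg

if gg.grouped_gemm_is_available():
    # Safe to select mlp_impl='grouped'; gg.ops and gg.backend are populated.
    pass
else:
    gg.assert_grouped_gemm_is_available()  # raises AssertionError with install instructions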
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py ADDED
@@ -0,0 +1,10 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ # from megablocks.layers.dmoe import dMoE
+ from megablocks.layers.moe import MoE
+
+ __all__ = [
+     'MoE',
+     # 'dMoE',
+ ]
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/activation_fn.py ADDED
@@ -0,0 +1,33 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ from typing import Any, Callable, Union
+
+ import torch
+ from stk import Matrix
+
+
+ def act_fn(
+     x: Matrix,
+     function: Callable,
+     return_grad_fn: bool = False,
+     **kwargs,
+ ) -> Union[tuple[Matrix, Any] | Matrix]:
+     assert isinstance(x, Matrix)
+     with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+         if return_grad_fn:
+             x.data.requires_grad = True
+         out = function(x.data, **kwargs)
+         y = Matrix(
+             x.size(),
+             out,
+             x.row_indices,
+             x.column_indices,
+             x.offsets,
+             x.column_indices_t,
+             x.offsets_t,
+             x.block_offsets_t,
+         )
+         if return_grad_fn:
+             return y, out.backward
+         return y
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/all_to_all.py ADDED
@@ -0,0 +1,54 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ import torch
+ import torch.distributed as dist
+
+
+ class AllToAllOp(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+         out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+         ctx.input_shape = x.shape
+         ctx.output_split_sizes = output_split_sizes
+         ctx.input_split_sizes = input_split_sizes
+         ctx.group = group
+         handle = dist.all_to_all_single(
+             out,
+             x,
+             output_split_sizes=output_split_sizes,
+             input_split_sizes=input_split_sizes,
+             group=group,
+             async_op=async_op,
+         )
+         return out, handle
+
+     @staticmethod
+     def backward(ctx, grad, _):
+         if ctx.needs_input_grad[0]:
+             out = torch.empty(
+                 ctx.input_shape,
+                 device=grad.device,
+                 dtype=grad.dtype,
+             )
+             dist.all_to_all_single(
+                 out,
+                 grad,
+                 output_split_sizes=ctx.input_split_sizes,
+                 input_split_sizes=ctx.output_split_sizes,
+                 group=ctx.group,
+             )
+             return out, None, None, None, None
+         return None, None, None, None, None
+
+
+ def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+     return AllToAllOp.apply(
+         x,
+         output_split_sizes,
+         input_split_sizes,
+         group,
+         async_op,
+     )
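A minimal sketch of the wrapper in an expert-parallel setting (hypothetical sizes; assumes torch.distributed is already initialized, e.g. via torchrun, that the row count divides evenly by the world size, and that split sizes agree across ranks):

import torch
import torch.distributed as dist
from megablocks.layers.all_to_all import all_to_all

rows, hidden = 8, 16
world_size = dist.get_world_size()
splits = [rows // world_size] * world_size       # rows sent to / received from each rank

x = torch.randn(rows, hidden, device='cuda')
out, handle = all_to_all(x, splits, splits, dist.group.WORLD, async_op=True)
handle.wait()                                    # overlap other work before waiting if desired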
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import dataclasses
5
+ from functools import partial
6
+ from typing import Any, Callable, Optional, Union
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ import torch.nn.functional as F
11
+
12
+ import megablocks.grouped_gemm_util as grouped_gemm
13
+
14
+ # Type annotation for in-place Tensor initialization function.
15
+ InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
16
+
17
+ _ALLOWED_BITWIDTHS = (-1, 4, 8)
18
+
19
+ DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class Arguments:
24
+ # Model arguments.
25
+ hidden_size: int = 1024
26
+ ffn_hidden_size: int = 4096
27
+ num_layers: int = 1
28
+ bias: bool = True
29
+ return_bias: bool = True
30
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
31
+
32
+ # MoE arguments.
33
+ moe_num_experts: int = 1
34
+ moe_top_k: int = 1
35
+ moe_capacity_factor: int = 1
36
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
37
+ moe_loss_weight: float = 0.1
38
+ moe_jitter_eps: Optional[float] = None
39
+ moe_lbl_in_fp32: bool = False
40
+
41
+ # Parallelism arguments.
42
+ moe_expert_model_parallelism: bool = False
43
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
44
+ pipeline_model_parallel_size: int = 1
45
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
46
+
47
+ # Compute arguments.
48
+ memory_optimized_mlp: bool = False
49
+ mlp_type: str = 'mlp'
50
+ mlp_impl: str = 'sparse'
51
+
52
+ # Initialization arguments.
53
+ fp16: bool = True
54
+ bf16: bool = False
55
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
56
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
57
+ output_layer_init_method: InitFn = init_method
58
+
59
+ # Benchmarking arguments.
60
+ uniform_expert_assignment: bool = False
61
+
62
+ # shared expert arguments
63
+ shared_expert: bool = False # enable using shared expert
64
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
65
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
66
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
67
+ shared_expert_hidden_size: Optional[
68
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
69
+ shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used)
70
+
71
+ # Router Z-loss arguments
72
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
73
+ moe_zloss_in_fp32: bool = False
74
+
75
+ def __post_init__(self):
76
+ # Sparse MLP is not supported with triton >=3.2.0
77
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
78
+ if self.__getattribute__('mlp_impl') == 'sparse':
79
+ try:
80
+ import triton
81
+ if triton.__version__ >= '3.2.0':
82
+ raise ValueError(
83
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
84
+ )
85
+ except ImportError:
86
+ raise ImportError('Triton is required for sparse MLP implementation')
87
+
88
+ if self.__getattribute__('mlp_impl') == 'grouped':
89
+ grouped_gemm.assert_grouped_gemm_is_available()
90
+
91
+ if self.shared_expert_hidden_size is None:
92
+ self.shared_expert_hidden_size = self.ffn_hidden_size
93
+
94
+
95
+ def from_megatron(megatron_args: Any):
96
+ args = Arguments()
97
+ for field in dataclasses.fields(args):
98
+ if hasattr(megatron_args, field.name):
99
+ setattr(args, field.name, getattr(megatron_args, field.name))
100
+ return args
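For reference, a typical dMoE configuration might be constructed like this (illustrative values only; mlp_impl='grouped' requires the grouped_gemm package per __post_init__, and the default device factory expects a CUDA device):

import torch
from megablocks.layers.arguments import Arguments

args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
    mlp_type='mlp',
    mlp_impl='grouped',      # 'sparse' additionally requires triton < 3.2.0
    fp16=False,
    bf16=True,
    device=torch.device('cuda'),
)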
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py ADDED
@@ -0,0 +1,26 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ import torch
+
+ from megablocks.layers.arguments import Arguments
+
+
+ def dtype(args: Arguments):
+     if args.fp16:
+         return torch.float16
+     elif args.bf16:
+         return torch.bfloat16
+     return None
+
+
+ def cast_if_autocast_enabled(tensor):
+     if torch.is_autocast_enabled():
+         if tensor.device.type == 'cuda':
+             dtype = torch.get_autocast_gpu_dtype()
+         elif tensor.device.type == 'cpu':
+             dtype = torch.get_autocast_cpu_dtype()
+         else:
+             raise NotImplementedError()
+         return tensor.to(dtype=dtype)
+     return tensor
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py ADDED
@@ -0,0 +1,42 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ from typing import Union
+
+ from megablocks.layers import glu, mlp
+ from megablocks.layers.arguments import Arguments
+
+ MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+ _REGISTRY = {
+     'mlp': {
+         'grouped': mlp.GroupedMLP,
+         'sparse': mlp.SparseMLP,
+     },
+     'glu': {
+         'grouped': glu.GroupedGLU,
+         'sparse': glu.SparseGLU,
+     },
+ }
+
+
+ def get(args: Arguments) -> MlpType:
+     """Returns an MLP for use in a dMoE instance.
+
+     Uses the provided arguments to instantiate the appropriate
+     MLP instance. This only contains MLPs for use in dMoEs
+     (i.e., only for the dropless versions of MoEs).
+
+     Args:
+       args: propagated Arguments dataclass.
+
+     Returns:
+       An instantiated MLP constructed using the input args.
+     """
+     if args.mlp_type not in _REGISTRY:
+         raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+     if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+         raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+     return _REGISTRY[args.mlp_type][args.mlp_impl](args)
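So, for example, a GLU-style expert MLP with the grouped backend would be resolved along these lines (assuming the same CUDA and grouped_gemm prerequisites as above; the values are illustrative):

from megablocks.layers import dmlp_registry
from megablocks.layers.arguments import Arguments

args = Arguments(mlp_type='glu', mlp_impl='grouped', fp16=False, bf16=True)
expert_mlp = dmlp_registry.get(args)   # returns a glu.GroupedGLU built from args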
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py ADDED
@@ -0,0 +1,327 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import stk.ops
6
+ import torch
7
+ from stk import Matrix
8
+
9
+ import megablocks.ops as ops
10
+ # from megablocks.ops import ops
11
+ from megablocks.layers import common, dmlp_registry, moe, mpu
12
+ from megablocks.layers.arguments import Arguments
13
+
14
+
15
+ def promote_scalar(x):
16
+ return x.view(1) if not len(x.size()) else x
17
+
18
+
19
+ class ParallelDroplessMLP(moe.ParallelMLP):
20
+
21
+ def __init__(self, args: Arguments):
22
+ super(ParallelDroplessMLP, self).__init__(args)
23
+ self.hidden_size = args.hidden_size
24
+ self.ffn_hidden_size = mpu.features_per_rank(args)
25
+ self.blocking = 128
26
+ self.mlp = dmlp_registry.get(args)
27
+
28
+ # Calculate the number of bits needed to represent the column indices
29
+ # in the intermediate sparse matrix.
30
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
31
+ self.transpose_sort_end_bit = max(
32
+ int(np.ceil(np.log2(max_column_index))),
33
+ 1,
34
+ )
35
+
36
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
37
+ block_columns = size[1] // self.blocking
38
+
39
+ # Sort row indices by column indices to get the transposed matrix's
40
+ # column indices.
41
+ #
42
+ # NOTE: Our sort operation uses the same width indices as the input values.
43
+ # To avoid overflow when we have large activation matrices we cast to
44
+ # 32-bit before sorting.
45
+ _, gather_indices = ops.sort(
46
+ column_indices.int(),
47
+ self.transpose_sort_end_bit,
48
+ )
49
+
50
+ # There are a constant number of blocks in every row of the sparse matrix.
51
+ # A blocks offset is:
52
+ #
53
+ # row_index * blocks_per_row + column_index % blocks_per_row
54
+ #
55
+ # Once we have the block offsets ordered for transposition we can divide
56
+ # by blocks_per_row to get the transposed column indices.
57
+ column_indices_t = row_indices.gather(0, gather_indices.long())
58
+ block_offsets_t = gather_indices.int()
59
+
60
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
61
+ nnz_per_column = ops.histogram(column_indices, block_columns)
62
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
63
+ if nnz_per_column.dim() == 0:
64
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
65
+ nnz_per_column = nnz_per_column.unsqueeze(0)
66
+ offsets_t = torch.cat([zero, nnz_per_column])
67
+ return column_indices_t, offsets_t, block_offsets_t
68
+
69
+ def topology(self, x, padded_bins):
70
+ padded_tokens, _ = x.size()
71
+ assert padded_tokens % self.blocking == 0
72
+ if self.ffn_hidden_size % self.blocking != 0:
73
+ raise ValueError(
74
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
75
+ f'the block size {self.blocking}. Please update your configuration.',
76
+ )
77
+
78
+ # Offsets for the sparse matrix. All rows have the
79
+ # same number of nonzero blocks dictated by the
80
+ # dimensionality of a single expert.
81
+ block_rows = padded_tokens // self.blocking
82
+ blocks_per_row = self.ffn_hidden_size // self.blocking
83
+ offsets = torch.arange(
84
+ 0,
85
+ block_rows * blocks_per_row + 1,
86
+ blocks_per_row,
87
+ dtype=torch.int32,
88
+ device=x.device,
89
+ )
90
+
91
+ # Indices for the sparse matrix. The indices for
92
+ # the intermediate matrix are dynamic depending
93
+ # on the mapping of tokens to experts.
94
+ column_indices = ops.topology(
95
+ padded_bins,
96
+ self.blocking,
97
+ block_rows,
98
+ blocks_per_row,
99
+ )
100
+
101
+ # TODO(tgale): This is unused. Remove the need for this in stk.
102
+ # For now, use meta init to save the device memory.
103
+ data = torch.empty(
104
+ column_indices.numel(),
105
+ self.blocking,
106
+ self.blocking,
107
+ dtype=common.dtype(self.args),
108
+ device='meta',
109
+ )
110
+ shape = (
111
+ padded_tokens,
112
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
113
+ )
114
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
115
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
116
+ shape,
117
+ row_indices,
118
+ column_indices,
119
+ offsets,
120
+ )
121
+ return stk.Matrix(
122
+ shape,
123
+ data,
124
+ row_indices,
125
+ column_indices,
126
+ offsets,
127
+ column_indices_t,
128
+ offsets_t,
129
+ block_offsets_t,
130
+ )
131
+
132
+ def indices_and_padded_bins(self, top_experts):
133
+ # Sort the expert ids to produce the scatter/gather
134
+ # indices for the permutation.
135
+ top_experts = top_experts.int()
136
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
137
+
138
+ # Histogram the expert ids to identify the number of
139
+ # tokens routed to each expert.
140
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
141
+
142
+ # Round the token counts up to the block size used in
143
+ # the matrix multiplications. Calculate the starting
144
+ # position of each bin.
145
+ padded_tokens_per_expert = ops.round_up(
146
+ tokens_per_expert,
147
+ self.blocking,
148
+ )
149
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
150
+ padded_bins = promote_scalar(padded_bins)
151
+
152
+ # Calculate the bin bounds for the sorted tokens.
153
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
154
+ bins = promote_scalar(bins)
155
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
156
+
157
+ def sparse_forward_once(self, x, expert_weights, top_experts):
158
+ # x: [sl, bs, hs]
159
+ # expert_weights: [sl * bs, top-k]
160
+ # top_experts: [sl * bs, top-k]
161
+ expert_weights = expert_weights.flatten()
162
+ top_experts = top_experts.flatten()
163
+ with torch.no_grad():
164
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
165
+
166
+ # Route the tokens for MoE computation.
167
+ x = x.view(-1, x.shape[-1])
168
+ x = ops.padded_gather(
169
+ x,
170
+ indices,
171
+ bin_ids,
172
+ bins,
173
+ padded_bins,
174
+ self.top_k,
175
+ )
176
+
177
+ # Create the sparse matrix topology.
178
+ with torch.no_grad():
179
+ topo = self.topology(x, padded_bins)
180
+
181
+ # Perform the expert computation.
182
+ x = self.mlp(x, topo)
183
+
184
+ # Un-route the data for the MoE output.
185
+ x = ops.padded_scatter(
186
+ x,
187
+ indices,
188
+ bin_ids,
189
+ expert_weights,
190
+ bins,
191
+ padded_bins,
192
+ self.top_k,
193
+ )
194
+ return x, tokens_per_expert
195
+
196
+ # For use in the base-class parallel_forward_once.
197
+ def sparse_permute_and_compute(
198
+ self,
199
+ x,
200
+ tokens_per_expert,
201
+ indices,
202
+ bin_ids,
203
+ expert_weights,
204
+ bins,
205
+ expert_capactiy, # unused
206
+ top_k,
207
+ ):
208
+
209
+ # Round the token counts up to the block size used in the matrix
210
+ # multiplication. Calculate the starting position of each bin.
211
+ padded_tokens_per_expert = ops.round_up(
212
+ tokens_per_expert,
213
+ self.blocking,
214
+ )
215
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
216
+ padded_bins = promote_scalar(padded_bins)
217
+
218
+ # Route the tokens for MoE computation.
219
+ x = x.view(-1, x.shape[-1])
220
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
221
+
222
+ # Create the sparse matrix topology.
223
+ with torch.no_grad():
224
+ topo = self.topology(x, padded_bins)
225
+
226
+ # Perform the expert computation.
227
+ x = self.mlp(x, topo)
228
+
229
+ # Un-route the data for the MoE output.
230
+ return ops.padded_scatter(
231
+ x,
232
+ indices,
233
+ bin_ids,
234
+ expert_weights,
235
+ bins,
236
+ padded_bins,
237
+ top_k,
238
+ )
239
+
240
+ def grouped_forward_once(self, x, expert_weights, top_experts):
241
+ # x: [sl, bs, hs]
242
+ # expert_weights: [sl * bs, top-k]
243
+ # top_experts: [sl * bs, top-k]
244
+ expert_weights = expert_weights.flatten()
245
+ top_experts = top_experts.flatten()
246
+ with torch.no_grad():
247
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
248
+
249
+ out = self.grouped_permute_and_compute(
250
+ x,
251
+ tokens_per_expert,
252
+ indices,
253
+ bin_ids,
254
+ expert_weights,
255
+ bins,
256
+ -1, # unused
257
+ self.args.moe_top_k,
258
+ )
259
+ return out, tokens_per_expert
260
+
261
+ def grouped_permute_and_compute(
262
+ self,
263
+ x,
264
+ tokens_per_expert,
265
+ indices,
266
+ bin_ids,
267
+ expert_weights,
268
+ bins,
269
+ expert_capactiy, # unused
270
+ top_k,
271
+ ):
272
+
273
+ # Route the tokens for MoE computation.
274
+ x = x.view(-1, x.shape[-1])
275
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
276
+
277
+ # Perform the expert computation.
278
+ x = self.mlp(x, tokens_per_expert)
279
+
280
+ # Un-route the data for the MoE output.
281
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
282
+
283
+ def forward_once(self, x, expert_weights, top_experts):
284
+ if self.args.mlp_impl == 'sparse':
285
+ return self.sparse_forward_once(x, expert_weights, top_experts)
286
+ else:
287
+ return self.grouped_forward_once(x, expert_weights, top_experts)
288
+
289
+ def permute_and_compute(
290
+ self,
291
+ x,
292
+ tokens_per_expert,
293
+ indices,
294
+ bin_ids,
295
+ expert_weights,
296
+ bins,
297
+ expert_capactiy,
298
+ top_k,
299
+ ):
300
+ if self.args.mlp_impl == 'sparse':
301
+ return self.sparse_permute_and_compute(
302
+ x,
303
+ tokens_per_expert,
304
+ indices,
305
+ bin_ids,
306
+ expert_weights,
307
+ bins,
308
+ expert_capactiy,
309
+ top_k,
310
+ )
311
+ else:
312
+ return self.grouped_permute_and_compute(
313
+ x,
314
+ tokens_per_expert,
315
+ indices,
316
+ bin_ids,
317
+ expert_weights,
318
+ bins,
319
+ expert_capactiy,
320
+ top_k,
321
+ )
322
+
323
+
324
+ class dMoE(moe.MoE):
325
+
326
+ def _init_experts_mlp(self, args: Arguments):
327
+ return ParallelDroplessMLP(args)
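Putting the pieces together, a dropless-MoE layer built from this module can be exercised roughly as follows (hypothetical sizes; assumes CUDA with bf16 support and an installed grouped_gemm for the grouped backend):

import torch
from megablocks.layers.arguments import Arguments
from megablocks.layers.dmoe import dMoE

args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
    mlp_impl='grouped',
    fp16=False,
    bf16=True,
    device=torch.device('cuda'),
)
layer = dMoE(args)
x = torch.randn(16, 2, 1024, device='cuda', dtype=torch.bfloat16)  # [seq_len, batch, hidden]
out = layer(x)  # see moe.py in this commit for the exact return convention (output plus optional bias)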
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/gelu.py ADDED
@@ -0,0 +1,43 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ import stk
+ import torch
+ import torch.nn.functional as F
+
+
+ @torch.jit.script
+ def _gelu_backward_inplace(g, x):
+     tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+     ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+     return g.mul_(ff)
+
+
+ def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+     # NOTE: The two sparse matrices must have the same topology.
+     if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+         return stk.Matrix(
+             x.size(),
+             _gelu_backward_inplace(grad.data, x.data),
+             x.row_indices,
+             x.column_indices,
+             x.offsets,
+             x.column_indices_t,
+             x.offsets_t,
+             x.block_offsets_t,
+         )
+     return _gelu_backward_inplace(grad, x)
+
+
+ def gelu(x: stk.Matrix):
+     assert isinstance(x, stk.Matrix)
+     return stk.Matrix(
+         x.size(),
+         F.gelu(x.data, approximate='tanh'),
+         x.row_indices,
+         x.column_indices,
+         x.offsets,
+         x.column_indices_t,
+         x.offsets_t,
+         x.block_offsets_t,
+     )
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py ADDED
@@ -0,0 +1,223 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import stk.ops
5
+ import torch
6
+
7
+ from megablocks import grouped_gemm_util as gg
8
+ from megablocks.layers import common, mpu
9
+ from megablocks.layers.activation_fn import act_fn
10
+ from megablocks.layers.arguments import Arguments
11
+ from megablocks.layers.mlp import (
12
+ SharedMLP,
13
+ SparseMLP,
14
+ create_dmoe_expert_weights,
15
+ resolve_dtensor,
16
+ )
17
+
18
+
19
+ class SparseGLU(SparseMLP):
20
+
21
+ def __init__(self, args: Arguments):
22
+ super().__init__(args)
23
+ self.v1 = torch.nn.Parameter(
24
+ torch.empty(
25
+ self._num_rows_per_rank,
26
+ args.hidden_size,
27
+ device=args.device,
28
+ dtype=common.dtype(args),
29
+ ),
30
+ )
31
+ with torch.no_grad():
32
+ self.v1.copy_(
33
+ create_dmoe_expert_weights(
34
+ args,
35
+ args.moe_num_experts,
36
+ args.ffn_hidden_size,
37
+ args.hidden_size,
38
+ args.init_method,
39
+ ),
40
+ )
41
+
42
+ mpu.set_expert_model_parallel_attributes(
43
+ self.v1,
44
+ self._should_set_parallelism_attribute,
45
+ )
46
+
47
+ def forward(self, x, topo):
48
+ if self.args.memory_optimized_mlp:
49
+ raise NotImplementedError(
50
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
51
+ )
52
+
53
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
54
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
55
+
56
+ # Compute the GLU.
57
+ x1 = stk.ops.sdd(x, w1.t(), topo)
58
+ x2 = stk.ops.sdd(x, v1.t(), topo)
59
+
60
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
61
+ x1 = stk.ops.mul(activation_fn_out, x2)
62
+
63
+ return stk.ops.dsd(x1, w2)
64
+
65
+
66
+ class MemoryOptimizedGroupedGLU(torch.autograd.Function):
67
+ """GroupedMLP with manually scheduled memory reuse."""
68
+
69
+ @staticmethod
70
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
71
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
72
+ # Cast inputs using ctx dtype from AMP
73
+ if ctx._fwd_used_autocast:
74
+ x = x.to(ctx._dtype)
75
+ w1 = w1.to(ctx._dtype)
76
+ v1 = v1.to(ctx._dtype)
77
+ w2 = w2.to(ctx._dtype)
78
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
79
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
80
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
81
+
82
+ # Layer 0: x @ w1.t().
83
+ assert gg.backend is not None
84
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
85
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
86
+
87
+ # GeLU.
88
+ activation_fn_out = activation_fn(sdd_out) * v1_out
89
+
90
+ # Layer 1: x @ w2.
91
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
92
+
93
+ # NOTE: Save the input to the layer and the activation_fn input for
94
+ # gradient computation. We'll re-compute the activation_fn forward
95
+ # pass in the backward pass to avoid materializing another
96
+ # intermediate.
97
+ ctx.x_shape = x.shape
98
+ ctx.sdd_out_shape = sdd_out.shape
99
+ ctx.dtype = x.dtype
100
+ ctx.activation_fn = activation_fn
101
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
102
+ return dsd_out
103
+
104
+ @staticmethod
105
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
106
+ def backward(ctx, ddsd_out):
107
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
108
+ raise ValueError('Expected all MLP inputs to need grad.')
109
+
110
+ # Unpack saved tensors
111
+ # dtype = ctx.dtype
112
+ saved_tensors = ctx.saved_tensors
113
+ w1, v1, w2 = saved_tensors[:3]
114
+ batch_sizes = saved_tensors[3]
115
+ x = saved_tensors[4]
116
+ sdd_out, v1_out = saved_tensors[5:7]
117
+
118
+ # Rematerialize activation_fn output.
119
+ activation_fn = ctx.activation_fn
120
+ with torch.set_grad_enabled(True):
121
+ sdd_out.requires_grad = True
122
+ v1_out.requires_grad = True
123
+ activation_fn_out = activation_fn(sdd_out) * v1_out
124
+ activation_grad_fn = activation_fn_out.backward
125
+
126
+ # Compute dw2 with recomputed activation_fn output.
127
+ assert gg.backend is not None
128
+ dw2 = gg.backend.gmm(
129
+ activation_fn_out,
130
+ ddsd_out,
131
+ batch_sizes,
132
+ trans_a=True,
133
+ )
134
+
135
+ # Compute dactivation_fn_out.
136
+ #
137
+ # NOTE: We reuse the activation_fn_out allocation.
138
+ dactivation_fn_out = activation_fn_out
139
+ gg.backend.gmm(
140
+ ddsd_out,
141
+ w2,
142
+ batch_sizes,
143
+ trans_b=True,
144
+ c=dactivation_fn_out,
145
+ )
146
+
147
+ # Compute dsdd_out.
148
+ #
149
+ # NOTE: This reuses the dactivation_fn_out allocation.
150
+ assert activation_grad_fn is not None
151
+ activation_grad_fn(dactivation_fn_out)
152
+ dsdd_out = sdd_out.grad
153
+ dv1_out = v1_out.grad
154
+
155
+ # Compute dw1.
156
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
157
+
158
+ # Compute dv1.
159
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
160
+
161
+ # Compute dx.
162
+ #
163
+ # NOTE: This reuses the ddsd_out allocation.
164
+ dx = ddsd_out
165
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
166
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
167
+ return dx, dw1, dv1, dw2, None, None
168
+
169
+
170
+ memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
171
+
172
+
173
+ class GroupedGLU(SparseGLU):
174
+
175
+ def forward(self, x, tokens_per_expert):
176
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
177
+ w1, v1, w2 = (
178
+ self.scale_grad(self.w1),
179
+ self.scale_grad(self.v1),
180
+ self.scale_grad(self.w2),
181
+ )
182
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
183
+
184
+ # Re-shape the weights for the grouped GEMMs.
185
+ ne = mpu.experts_per_rank(self.args)
186
+ w1 = w1.view(ne, -1, self.args.hidden_size)
187
+ v1 = v1.view(ne, -1, self.args.hidden_size)
188
+ w2 = w2.view(ne, -1, self.args.hidden_size)
189
+
190
+ if self.args.memory_optimized_mlp:
191
+ return memory_optimized_grouped_glu(
192
+ x,
193
+ w1,
194
+ v1,
195
+ w2,
196
+ batch_sizes,
197
+ self.args.activation_fn,
198
+ )
199
+
200
+ # Compute the MLP.
201
+ assert gg.ops is not None
202
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
203
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
204
+ x1 = self.args.activation_fn(x1) * x2
205
+ return gg.ops.gmm(x1, w2, batch_sizes)
206
+
207
+
208
+ class SharedGLU(SharedMLP):
209
+ """GPU for shared expert.
210
+
211
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
212
+ """
213
+
214
+ def __init__(self, args: Arguments):
215
+ super().__init__(args)
216
+ self.gate_proj = args.fc_cls(
217
+ args.hidden_size,
218
+ self.args.shared_expert_hidden_size,
219
+ **self.fc_kwargs,
220
+ )
221
+
222
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
223
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
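As a rough mental model for the GLU variants in this file (not part of the commit itself): SparseGLU and GroupedGLU compute, per expert, the same math as the dense sketch below, just with block-sparse or grouped GEMMs. The helper name and shapes here are illustrative assumptions, not code from the package.

import torch
import torch.nn.functional as F

def dense_glu_reference(x, w1, v1, w2, activation_fn=lambda t: F.gelu(t, approximate='tanh')):
    # x: [tokens, hidden], w1/v1: [ffn, hidden], w2: [ffn, hidden]
    gate = activation_fn(x @ w1.t())   # gated branch (w1)
    up = x @ v1.t()                    # linear branch (v1)
    return (gate * up) @ w2            # project back to the hidden size

x = torch.randn(4, 8)
w1, v1, w2 = (torch.randn(16, 8) for _ in range(3))
assert dense_glu_reference(x, w1, v1, w2).shape == (4, 8)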
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import gc
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+ from megablocks.layers import arguments, dmoe
10
+
11
+ _TESTS = ((8, 2048, 4096, 4096, 32, 4),)
12
+
13
+
14
+ def get_tensors():
15
+ ptrs = set()
16
+ out = []
17
+ for obj in gc.get_objects():
18
+ if torch.is_tensor(obj):
19
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
20
+ continue
21
+ out.append(obj)
22
+ ptrs.add(obj.data_ptr())
23
+ return out
24
+
25
+
26
+ def test_memory(
27
+ group,
28
+ batch_size,
29
+ sequence_length,
30
+ hidden_size,
31
+ ffn_hidden_size,
32
+ num_experts,
33
+ top_k,
34
+ ):
35
+ args = arguments.Arguments(
36
+ hidden_size=hidden_size,
37
+ ffn_hidden_size=ffn_hidden_size,
38
+ moe_num_experts=num_experts,
39
+ moe_top_k=top_k,
40
+ moe_expert_model_parallelism=True,
41
+ expert_parallel_group=group,
42
+ fp16=False,
43
+ bf16=True,
44
+ device=torch.cuda.current_device(),
45
+ )
46
+ layer = dmoe.dMoE(args).cuda()
47
+
48
+ x = torch.randn((batch_size, sequence_length, hidden_size),
49
+ device=torch.cuda.current_device(),
50
+ dtype=torch.bfloat16).requires_grad_(True)
51
+ torch.cuda.empty_cache()
52
+
53
+ # Run forward + backward.
54
+ # with torch.autograd.detect_anomaly():
55
+ out, _ = layer(x)
56
+ out.mean().backward()
57
+
58
+ # Report peak memory.
59
+ mem = torch.cuda.max_memory_allocated()
60
+ print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
61
+ print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
62
+
63
+ # Calculate weight and gradient memory usage.
64
+ weight_memory = 2 * (
65
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
66
+ )
67
+
68
+ def grad_numel(x):
69
+ if x.grad is not None:
70
+ return x.grad.numel()
71
+ return 0
72
+
73
+ grad_memory = 2 * (
74
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
75
+ )
76
+ weight_memory += grad_memory
77
+
78
+ print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
79
+ print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6,),)
80
+
81
+ # Manually calculate GPU memory usage from the garbage
82
+ # collector.
83
+ gc.collect()
84
+ total = 0
85
+ tensors = get_tensors()
86
+ tensors = sorted(tensors, key=lambda x: -x.numel())
87
+ for i, t in enumerate(tensors):
88
+ total += t.numel()
89
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
90
+ del tensors
91
+
92
+ print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
93
+
94
+
95
+ if __name__ == '__main__':
96
+ assert dist.is_available()
97
+ group = dist.init_process_group(backend='nccl')
98
+ local_rank = dist.get_rank(group)
99
+ torch.cuda.set_device(local_rank)
100
+
101
+ for args in _TESTS:
102
+ test_memory(group, *args)
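The figures printed above follow a simple bf16 accounting: 2 bytes per weight element, doubled again for gradients. The sketch below reproduces that arithmetic for the single `_TESTS` configuration; it deliberately ignores expert model parallelism, which shards the expert weights across ranks, so treat it as an upper bound rather than the exact per-rank number.

hidden_size, ffn_hidden_size, num_experts = 4096, 4096, 32
router_params = hidden_size * num_experts
expert_params = 2 * num_experts * hidden_size * ffn_hidden_size  # w1 + w2
weight_bytes = 2 * (router_params + expert_params)               # bf16 weights
weight_bytes *= 2                                                # matching gradients
print('~{:0.0f}MB of weights + grads without expert parallelism'.format(weight_bytes / 1e6))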
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py ADDED
@@ -0,0 +1,574 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ import stk
7
+ import stk.backend.triton_kernels
8
+ import stk.ops
9
+ import torch
10
+ from packaging import version
11
+
12
+ from megablocks import grouped_gemm_util as gg
13
+ from megablocks.layers import common, gelu, mpu
14
+ from megablocks.layers.activation_fn import act_fn
15
+ from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
16
+
17
+
18
+ class ScaleGradient(torch.autograd.Function):
19
+
20
+ @staticmethod
21
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
22
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
23
+ ctx.scale = scale
24
+ return x
25
+
26
+ @staticmethod
27
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
28
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
29
+ return grad * ctx.scale, None
30
+
31
+
32
+ scale_gradient = ScaleGradient.apply
33
+
34
+
35
+ def resolve_dtensor(weight: torch.Tensor):
36
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
37
+ from torch.distributed._tensor import DTensor
38
+ if isinstance(weight, DTensor):
39
+ return weight.to_local()
40
+ return weight
41
+
42
+
43
+ def create_moe_expert_weights(
44
+ args: Arguments,
45
+ num_experts: int,
46
+ ffn_hidden_size: int,
47
+ hidden_size: int,
48
+ init_method: InitFn,
49
+ ):
50
+ # Create the entire weight matrix such that the sampled weights will
51
+ # not vary between data parallelism and expert model parallelism for
52
+ # the same random seed.
53
+ master_weights = torch.empty(
54
+ num_experts,
55
+ ffn_hidden_size,
56
+ hidden_size,
57
+ device=args.device,
58
+ dtype=common.dtype(args),
59
+ )
60
+ init_method(master_weights)
61
+
62
+ if not args.moe_expert_model_parallelism:
63
+ return master_weights
64
+
65
+ # Calculate the amount of sharding in each dimension.
66
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
67
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
68
+
69
+ # Calculate the experts per rank.
70
+ #
71
+ # NOTE: We assign ranks to be expert parallel before going
72
+ # tensor parallel.
73
+ rank = mpu.get_expert_parallel_rank(args)
74
+ expert_rank = rank % expert_sharding_degree
75
+ num_experts_per_rank = num_experts // expert_sharding_degree
76
+ start_expert = expert_rank * num_experts_per_rank
77
+ end_expert = (expert_rank + 1) * num_experts_per_rank
78
+
79
+ # Calculate the rows per rank.
80
+ row_rank = rank // expert_sharding_degree
81
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
82
+ start_row = row_rank * num_rows_per_rank
83
+ end_row = (row_rank + 1) * num_rows_per_rank
84
+
85
+ # Slice the weight matrix to get the chunk for this rank.
86
+ with torch.no_grad():
87
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
88
+ return weights
89
+
90
+
91
+ class MLP(torch.nn.Module):
92
+
93
+ def __init__(self, args: Arguments):
94
+ super().__init__()
95
+ self.args = args
96
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
97
+ experts_per_rank = mpu.experts_per_rank(args)
98
+
99
+ self.w1 = torch.nn.Parameter(
100
+ torch.empty(
101
+ experts_per_rank,
102
+ args.hidden_size,
103
+ mpu.features_per_rank(args),
104
+ device=args.device,
105
+ dtype=common.dtype(args),
106
+ ),
107
+ )
108
+ self.w2 = torch.nn.Parameter(
109
+ torch.empty(
110
+ experts_per_rank,
111
+ mpu.features_per_rank(args),
112
+ args.hidden_size,
113
+ device=args.device,
114
+ dtype=common.dtype(args),
115
+ ),
116
+ )
117
+ mpu.set_expert_model_parallel_attributes(
118
+ self.w1,
119
+ args.moe_expert_model_parallelism,
120
+ )
121
+ mpu.set_expert_model_parallel_attributes(
122
+ self.w2,
123
+ args.moe_expert_model_parallelism,
124
+ )
125
+
126
+ # Initialize the parameters for the MLP.
127
+ #
128
+ # NOTE: It is important that we create the weight tensors prior
129
+ # to creating the master weights and slicing our the piece for
130
+ # this rank. If the master weights are created first the PyTorch
131
+ # caching allocator appears to use the same memory block for these
132
+ # and the slice which causes large increases in our peak memory
133
+ # usage.
134
+ with torch.no_grad():
135
+ w1 = create_moe_expert_weights(
136
+ args,
137
+ args.moe_num_experts,
138
+ args.ffn_hidden_size,
139
+ args.hidden_size,
140
+ args.init_method,
141
+ )
142
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
143
+ self.w2.copy_(
144
+ create_moe_expert_weights(
145
+ args,
146
+ args.moe_num_experts,
147
+ args.ffn_hidden_size,
148
+ args.hidden_size,
149
+ args.output_layer_init_method,
150
+ ),
151
+ )
152
+
153
+ self.gradient_scale = None
154
+ if self.args.moe_expert_model_parallelism:
155
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
156
+
157
+ def scale_grad(self, w):
158
+ if self.gradient_scale is None:
159
+ return w
160
+ return scale_gradient(w, self.gradient_scale)
161
+
162
+ def forward(self, x):
163
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
164
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
165
+ x = torch.bmm(x, w1)
166
+ x = self.args.activation_fn(x)
167
+ return torch.bmm(x, w2)
168
+
169
+
170
+ def create_dmoe_expert_weights(
171
+ args: Arguments,
172
+ num_experts: int,
173
+ rows: int,
174
+ columns: int,
175
+ init_method: InitFn,
176
+ ):
177
+ weights = create_moe_expert_weights(
178
+ args,
179
+ num_experts,
180
+ rows,
181
+ columns,
182
+ init_method,
183
+ )
184
+ return weights.view([-1, columns])
185
+
186
+
187
+ class MemoryOptimizedMLP(torch.autograd.Function):
188
+ """Sparse MLP with manually scheduled memory reuse."""
189
+
190
+ @staticmethod
191
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
192
+ def forward(ctx, x, w1, w2, topo, activation_fn):
193
+ # Cast inputs using ctx dtype from AMP
194
+ if ctx._fwd_used_autocast:
195
+ x = x.to(ctx._dtype)
196
+ w1 = w1.to(ctx._dtype)
197
+ w2 = w2.to(ctx._dtype)
198
+ # x: [m, k], w1: [n, k], w2: [n, k]
199
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
200
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
201
+
202
+ topo_tensors = (
203
+ topo.row_indices,
204
+ topo.column_indices,
205
+ topo.offsets,
206
+ topo.column_indices_t,
207
+ topo.offsets_t,
208
+ topo.block_offsets_t,
209
+ )
210
+
211
+ # Layer 0: x @ w1.t().
212
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
213
+
214
+ # GeLU.
215
+ activation_fn_out = act_fn(sdd_out, activation_fn)
216
+
217
+ # Layer 1: x @ w2.
218
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
219
+
220
+ # NOTE: Save the input to the layer and the activation_fn input for
221
+ # gradient computation. We'll re-compute the activation_fn forward
222
+ # pass in the backward pass to avoid materializing another
223
+ # intermediate.
224
+ ctx.shape = topo.shape
225
+ ctx.x_shape = x.shape
226
+ ctx.sdd_out_shape = sdd_out.data.shape
227
+ ctx.dtype = x.dtype
228
+ ctx.activation_fn = activation_fn
229
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
230
+ return dsd_out
231
+
232
+ @staticmethod
233
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
234
+ def backward(ctx, ddsd_out):
235
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
236
+ raise ValueError('Expected all MLP inputs to need grad.')
237
+
238
+ # unpack saved tensors
239
+ # dtype = ctx.dtype
240
+ saved_tensors = ctx.saved_tensors
241
+ w1, w2 = saved_tensors[:2]
242
+ topo_tensors = saved_tensors[2:8]
243
+ x = saved_tensors[8]
244
+ sdd_out_data = saved_tensors[9]
245
+
246
+ # rematerialize activation function output
247
+ activation_fn = ctx.activation_fn
248
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
249
+ activation_fn_out, activation_grad_fn = act_fn(
250
+ sdd_out,
251
+ activation_fn,
252
+ return_grad_fn=True,
253
+ )
254
+
255
+ # Compute dw2 with recomputed activation_fn output.
256
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
257
+
258
+ # Compute dactivation_fn_out.
259
+ #
260
+ # NOTE: We reuse the activation_fn_out allocation.
261
+ dactivation_fn_out = activation_fn_out
262
+ stk.backend.triton_kernels.sdd(
263
+ ddsd_out,
264
+ w2.t(),
265
+ dactivation_fn_out.shape,
266
+ dactivation_fn_out.data,
267
+ dactivation_fn_out.offsets,
268
+ dactivation_fn_out.row_indices,
269
+ dactivation_fn_out.column_indices,
270
+ )
271
+
272
+ # Compute dsdd_out.
273
+ #
274
+ # NOTE: This reuses the dactivation_fn_out allocation.
275
+ if activation_fn is DEFAULT_ACTIVATION_FN:
276
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
277
+ else:
278
+ assert activation_grad_fn is not None
279
+ activation_grad_fn(dactivation_fn_out.data)
280
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
281
+
282
+ # Compute dw1.
283
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
284
+
285
+ # Compute dx.
286
+ #
287
+ # NOTE: This reuses the ddsd_out allocation.
288
+ stk.backend.triton_kernels.dsd(
289
+ dsdd_out.shape,
290
+ dsdd_out.data,
291
+ dsdd_out.offsets,
292
+ dsdd_out.row_indices,
293
+ dsdd_out.column_indices,
294
+ dsdd_out.offsets_t,
295
+ dsdd_out.column_indices_t,
296
+ dsdd_out.block_offsets_t,
297
+ False,
298
+ w1,
299
+ ddsd_out,
300
+ )
301
+ dx = ddsd_out
302
+ return dx, dw1, dw2, None, None
303
+
304
+
305
+ memory_optimized_mlp = MemoryOptimizedMLP.apply
306
+
307
+
308
+ class SparseMLP(torch.nn.Module):
309
+
310
+ def __init__(self, args: Arguments):
311
+ super().__init__()
312
+ self.args = args
313
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
314
+
315
+ self.w1 = torch.nn.Parameter(
316
+ torch.empty(
317
+ self._num_rows_per_rank,
318
+ args.hidden_size,
319
+ device=args.device,
320
+ dtype=common.dtype(args),
321
+ ),
322
+ )
323
+ self.w2 = torch.nn.Parameter(
324
+ torch.empty(
325
+ self._num_rows_per_rank,
326
+ args.hidden_size,
327
+ device=args.device,
328
+ dtype=common.dtype(args),
329
+ ),
330
+ )
331
+
332
+ # Initialize the parameters for the MLP.
333
+ #
334
+ # NOTE: It is important that we create the weight tensors prior
335
+ # to creating the master weights and slicing our the piece for
336
+ # this rank. If the master weights are created first the PyTorch
337
+ # caching allocator appears to use the same memory block for these
338
+ # and the slice which causes large increases in our peak memory
339
+ # usage.
340
+ with torch.no_grad():
341
+ self.w1.copy_(
342
+ create_dmoe_expert_weights(
343
+ args,
344
+ args.moe_num_experts,
345
+ args.ffn_hidden_size,
346
+ args.hidden_size,
347
+ args.init_method,
348
+ ),
349
+ )
350
+ self.w2.copy_(
351
+ create_dmoe_expert_weights(
352
+ args,
353
+ args.moe_num_experts,
354
+ args.ffn_hidden_size,
355
+ args.hidden_size,
356
+ args.output_layer_init_method,
357
+ ),
358
+ )
359
+
360
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
361
+ mpu.set_expert_model_parallel_attributes(
362
+ self.w1,
363
+ self._should_set_parallelism_attribute,
364
+ )
365
+ mpu.set_expert_model_parallel_attributes(
366
+ self.w2,
367
+ self._should_set_parallelism_attribute,
368
+ )
369
+
370
+ self.gradient_scale = None
371
+ if self.args.moe_expert_model_parallelism:
372
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
373
+
374
+ def scale_grad(self, w):
375
+ if self.gradient_scale is None:
376
+ return w
377
+ return scale_gradient(w, self.gradient_scale)
378
+
379
+ def forward(self, x, topo):
380
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
381
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
382
+ if self.args.memory_optimized_mlp:
383
+ return memory_optimized_mlp(
384
+ x,
385
+ w1,
386
+ w2,
387
+ topo,
388
+ self.args.activation_fn,
389
+ )
390
+
391
+ # Compute the MLP.
392
+ x = stk.ops.sdd(x, w1.t(), topo)
393
+ activation_fn_out = act_fn(x, self.args.activation_fn)
394
+ return stk.ops.dsd(activation_fn_out, w2)
395
+
396
+
397
+ class MemoryOptimizedGroupedMLP(torch.autograd.Function):
398
+ """GroupedMLP with manually scheduled memory reuse."""
399
+
400
+ @staticmethod
401
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
402
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
403
+ # Cast inputs using ctx dtype from AMP
404
+ if ctx._fwd_used_autocast:
405
+ x = x.to(ctx._dtype)
406
+ w1 = w1.to(ctx._dtype)
407
+ w2 = w2.to(ctx._dtype)
408
+ # x: [m, k], w1: [n, k], w2: [n, k]
409
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
410
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
411
+
412
+ # Layer 0: x @ w1.t().
413
+ assert gg.backend is not None
414
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
415
+
416
+ # activation_fn
417
+ activation_fn_out = activation_fn(sdd_out)
418
+
419
+ # Layer 1: x @ w2.
420
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
421
+
422
+ # NOTE: Save the input to the layer and the activation_fn input for
423
+ # gradient computation. We'll re-compute the activation_fn forward
424
+ # pass in the backward pass to avoid materializing another
425
+ # intermediate.
426
+ ctx.x_shape = x.shape
427
+ ctx.sdd_out_shape = sdd_out.shape
428
+ ctx.dtype = x.dtype
429
+ ctx.activation_fn = activation_fn
430
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
431
+ return dsd_out
432
+
433
+ @staticmethod
434
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
435
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
436
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
437
+ raise ValueError('Expected all MLP inputs to need grad.')
438
+
439
+ # Unpack saved tensors
440
+ # dtype = ctx.dtype
441
+ saved_tensors = ctx.saved_tensors
442
+ w1, w2 = saved_tensors[:2]
443
+ batch_sizes = saved_tensors[2]
444
+ x = saved_tensors[3]
445
+ sdd_out = saved_tensors[4]
446
+
447
+ # Rematerialize activation_fn output.
448
+ activation_fn = ctx.activation_fn
449
+ with torch.set_grad_enabled(True):
450
+ sdd_out.requires_grad = True
451
+ activation_fn_out = activation_fn(sdd_out)
452
+ activation_grad_fn = activation_fn_out.backward
453
+
454
+ # Compute dw2 with recomputed activation_fn output.
455
+ assert gg.backend is not None
456
+ dw2 = gg.backend.gmm(
457
+ activation_fn_out,
458
+ ddsd_out,
459
+ batch_sizes,
460
+ trans_a=True,
461
+ )
462
+
463
+ # Compute dactivation_fn_out.
464
+ #
465
+ # NOTE: We reuse the activation_fn_out allocation.
466
+ dactivation_fn_out = activation_fn_out
467
+ gg.backend.gmm(
468
+ ddsd_out,
469
+ w2,
470
+ batch_sizes,
471
+ trans_b=True,
472
+ c=dactivation_fn_out,
473
+ )
474
+
475
+ # Compute dsdd_out.
476
+ #
477
+ # NOTE: This reuses the dactivation_fn_out allocation.
478
+ if activation_fn is DEFAULT_ACTIVATION_FN:
479
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
480
+ else:
481
+ assert activation_grad_fn is not None
482
+ activation_grad_fn(dactivation_fn_out)
483
+ dsdd_out = sdd_out.grad
484
+
485
+ # Compute dw1.
486
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
487
+
488
+ # Compute dx.
489
+ #
490
+ # NOTE: This reuses the ddsd_out allocation.
491
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
492
+ dx = ddsd_out
493
+ return dx, dw1, dw2, None, None
494
+
495
+
496
+ memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
497
+
498
+
499
+ class GroupedMLP(SparseMLP):
500
+
501
+ def forward(self, x, tokens_per_expert):
502
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
503
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
504
+
505
+ # Re-shape the weights for the grouped GEMMs.
506
+ ne = mpu.experts_per_rank(self.args)
507
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
508
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
509
+
510
+ if self.args.memory_optimized_mlp:
511
+ return memory_optimized_grouped_mlp(
512
+ x,
513
+ w1,
514
+ w2,
515
+ batch_sizes,
516
+ self.args.activation_fn,
517
+ )
518
+
519
+ # Compute the MLP.
520
+ assert gg.ops is not None
521
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
522
+ x = self.args.activation_fn(x)
523
+ return gg.ops.gmm(x, w2, batch_sizes)
524
+
525
+
526
+ class SharedMLP(torch.nn.Module):
527
+ """MLP for shared expert.
528
+
529
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
530
+ """
531
+
532
+ def __init__(self, args: Arguments):
533
+ super().__init__()
534
+ self.args = args
535
+ self.fc_kwargs: dict[str, Any] = {
536
+ 'bias': args.bias,
537
+ 'device': args.device,
538
+ }
539
+ self.fc_kwargs.update(args.fc_kwargs)
540
+
541
+ self.up_proj = args.fc_cls(
542
+ args.hidden_size,
543
+ args.shared_expert_hidden_size,
544
+ **self.fc_kwargs,
545
+ )
546
+ self.act = args.activation_fn
547
+ self.down_proj = args.fc_cls(
548
+ args.shared_expert_hidden_size,
549
+ args.hidden_size,
550
+ **self.fc_kwargs,
551
+ )
552
+ self.down_proj._is_residual = True # a flag for llm-foundry init
553
+
554
+ def add_experts_sharedexpert(
555
+ self,
556
+ shared_expert_out: torch.Tensor,
557
+ expert_out: torch.Tensor,
558
+ ) -> torch.Tensor:
559
+ # Helper function to add expert output to shared expert output
560
+ # with optional weighted sum.
561
+ if self.args.shared_expert_weighted_sum:
562
+ # enable using weighted sum for shared expert output
563
+ # weighted by the number of experts used
564
+ t_experts = self.args.moe_top_k + 1
565
+ sh_mlp_out = shared_expert_out / t_experts
566
+ return sh_mlp_out.add(
567
+ expert_out,
568
+ alpha=(self.args.moe_top_k / t_experts),
569
+ )
570
+
571
+ return shared_expert_out + expert_out
572
+
573
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
574
+ return self.down_proj(self.act(self.up_proj(x)))
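The MemoryOptimized* autograd functions in this file trade compute for memory: they save the pre-activation and rerun the activation in the backward pass instead of keeping its output alive. The dense, single-expert sketch below shows only that recomputation pattern (plain GEMMs; no block-sparse or grouped kernels, no AMP handling), so read it as an assumption-laden illustration rather than the actual code path.

import torch
import torch.nn.functional as F

class RecomputeActMLP(torch.autograd.Function):
    """Dense 2-layer MLP that rematerializes the activation output in backward."""

    @staticmethod
    def forward(ctx, x, w1, w2, act):
        h = x @ w1.t()                         # layer 0 pre-activation
        y = act(h) @ w2                        # layer 1
        ctx.act = act
        ctx.save_for_backward(x, w1, w2, h)    # save h, not act(h)
        return y

    @staticmethod
    def backward(ctx, dy):
        x, w1, w2, h = ctx.saved_tensors
        with torch.enable_grad():
            h = h.detach().requires_grad_(True)
            a = ctx.act(h)                     # recompute the activation output
        dw2 = a.t() @ dy
        da = dy @ w2.t()
        a.backward(da)                         # gradient through the activation
        dh = h.grad
        return dh @ w1, dh.t() @ x, dw2, None

x = torch.randn(4, 8, requires_grad=True)
w1 = torch.randn(16, 8, requires_grad=True)
w2 = torch.randn(16, 8, requires_grad=True)
RecomputeActMLP.apply(x, w1, w2, F.gelu).sum().backward()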
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py ADDED
@@ -0,0 +1,475 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+ import megablocks.ops as ops
10
+ from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
11
+ from megablocks.layers.all_to_all import all_to_all
12
+ from megablocks.layers.arguments import Arguments
13
+
14
+ _LOAD_BALANCING_LOSS = []
15
+
16
+
17
+ def save_load_balancing_loss(loss):
18
+ global _LOAD_BALANCING_LOSS
19
+ _LOAD_BALANCING_LOSS.append(loss)
20
+
21
+
22
+ def get_load_balancing_loss():
23
+ global _LOAD_BALANCING_LOSS
24
+ return _LOAD_BALANCING_LOSS
25
+
26
+
27
+ def clear_load_balancing_loss():
28
+ global _LOAD_BALANCING_LOSS
29
+ _LOAD_BALANCING_LOSS.clear()
30
+
31
+
32
+ def batched_load_balancing_loss(args: Arguments):
33
+ if args.moe_loss_weight == 0:
34
+ return 0.0
35
+
36
+ # tokens_per_expert[i].shape = (num_experts)
37
+ # expert_scores[i].shape = (tokens, num_experts)
38
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
39
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
40
+ if args.num_layers_per_virtual_pipeline_stage is not None:
41
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
42
+
43
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
44
+ raise ValueError(
45
+ f'Expected {num_layers_per_pipeline_stage} token_per_experts '
46
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
47
+ f'{args.num_layers}\npipeline_model_parallel_size = '
48
+ f'{args.pipeline_model_parallel_size}\n'
49
+ 'num_layers_per_virtual_pipeline_stage'
50
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
51
+ )
52
+ if len(expert_scores) != num_layers_per_pipeline_stage:
53
+ raise ValueError(
54
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
55
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
56
+ f'{args.num_layers}\npipeline_model_parallel_size = '
57
+ f'{args.pipeline_model_parallel_size}\n'
58
+ 'num_layers_per_virtual_pipeline_stage'
59
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
60
+ )
61
+
62
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
63
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
64
+
65
+ tokens = expert_scores[0].shape[0]
66
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
67
+
68
+ # Concatenate the contributions of each layer and convert to
69
+ # the correct types and formats for the dot product.
70
+ expert_scores = torch.cat(expert_scores, dim=1)
71
+ if args.moe_lbl_in_fp32:
72
+ expert_scores = expert_scores.float()
73
+ if tokens != 0:
74
+ expert_scores = expert_scores.mean(dim=0)
75
+ else:
76
+ expert_scores = expert_scores.sum(dim=0)
77
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
78
+
79
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
80
+ assert tokens_per_expert.numel() == expected_values
81
+ assert expert_scores.numel() == expected_values
82
+
83
+ # Calculate the total scale across all factors.
84
+ #
85
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
86
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
87
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
88
+ scale = scale_numerator / scale_denominator
89
+ return scale * torch.dot(tokens_per_expert, expert_scores)
90
+
91
+
92
+ # NOTE: This class defines MoE expert computation, including expert model parallel
93
+ # communication. When using FSDP on top of MegaBlocks this is the module that should
94
+ # be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
95
+ # parallel all2all.
96
+ class ParallelMLP(torch.nn.Module):
97
+
98
+ def __init__(self, args: Arguments):
99
+ super(ParallelMLP, self).__init__()
100
+ self.args = args
101
+
102
+ # Calculate the number of experts in total and the number of experts
103
+ # owned by this rank.
104
+ # world_size = mpu.get_expert_parallel_world_size(args)
105
+ self.num_experts = args.moe_num_experts
106
+ self.top_k = self.args.moe_top_k
107
+
108
+ # Calculate the number of bits needed to represent the expert indices
109
+ # so that we can pass it to radix sort.
110
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
111
+
112
+ # Expert MLP.
113
+ self.mlp = mlp.MLP(args)
114
+
115
+ self.bias: Optional[torch.Tensor]
116
+ if self.args.bias:
117
+ # Note that the output bias is not parallelized with expert
118
+ # model parallelism.
119
+ self.bias = torch.nn.Parameter(
120
+ torch.empty(
121
+ args.hidden_size,
122
+ device=args.device,
123
+ dtype=common.dtype(args),
124
+ ),
125
+ )
126
+ torch.nn.init.zeros_(self.bias)
127
+ else:
128
+ self.register_parameter('bias', None)
129
+
130
+ # Select the forward function for the operating mode.
131
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
132
+
133
+ def expert_capacity(self, tokens: int) -> int:
134
+ world_size = mpu.get_expert_parallel_world_size(self.args)
135
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
136
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
137
+
138
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
139
+ """Calculate the load balancing loss contribution."""
140
+ assert len(expert_scores.size()) == 2
141
+ tokens, num_experts = expert_scores.size()
142
+ assert num_experts == self.num_experts
143
+ assert len(tokens_per_expert.size()) == 1
144
+ num_experts, = tokens_per_expert.size()
145
+ assert num_experts == self.num_experts
146
+ scale = self.num_experts / (tokens * self.top_k)
147
+ return scale * torch.dot(
148
+ tokens_per_expert.to(expert_scores.dtype),
149
+ expert_scores.mean(dim=0),
150
+ )
151
+
152
+ def indices_and_bins(self,
153
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
154
+ # Sort the expert ids to produce the scatter/gather
155
+ # indices for the permutation.
156
+ #
157
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
158
+ # prior? Could we place the `torch.max` operation to return
159
+ # 32-bit expert indices?
160
+ top_expert = top_expert.int()
161
+ output = ops.sort(top_expert, self.sort_end_bit)
162
+ assert output is not None
163
+ bin_ids, indices = output
164
+
165
+ # Histogram the expert ids to identify the number of
166
+ # tokens routed to each expert.
167
+ #
168
+ # TODO(tgale): Does the sorted data produce a more favorable
169
+ # data distribution for histogram? Or is the op parallelism
170
+ # worth more?
171
+ tokens_per_expert = ops.histogram(top_expert, self.num_experts)
172
+
173
+ # Calculate the bin bounds for the sorted tokens.
174
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
175
+ assert bins is not None
176
+ bins = bins.view(1) if not len(bins.size()) else bins
177
+
178
+ assert isinstance(indices, torch.Tensor)
179
+ assert isinstance(bin_ids, torch.Tensor)
180
+ assert isinstance(bins, torch.Tensor)
181
+ assert isinstance(tokens_per_expert, torch.Tensor)
182
+
183
+ return indices, bin_ids, bins, tokens_per_expert
184
+
185
+ def permute_and_compute(
186
+ self,
187
+ x: torch.Tensor,
188
+ tokens_per_expert: int, # unused
189
+ indices: torch.Tensor,
190
+ bin_ids: torch.Tensor, # unused
191
+ expert_weights: torch.Tensor,
192
+ bins: torch.Tensor,
193
+ expert_capacity: int,
194
+ top_k: int,
195
+ ):
196
+ # Route the tokens for MoE computation.
197
+ x = x.view(-1, x.shape[-1])
198
+ output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
199
+ assert output is not None
200
+ x = output
201
+
202
+ # Perform the expert computation. Note that we don't
203
+ # use biases for these linear operations.
204
+ x = self.mlp(x)
205
+
206
+ # Un-route the data for the MoE output.
207
+ return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
208
+
209
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
210
+ # x: [sl, bs, hs]
211
+ # expert_weights: [sl * bs, top-k]
212
+ # top_experts: [sl * bs, top-k]
213
+ expert_weights = expert_weights.flatten()
214
+ top_experts = top_experts.flatten()
215
+ with torch.no_grad():
216
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
217
+
218
+ # If expert_capacity is set to zero, set the number of tokens
219
+ # per expert to the maximum we need to avoid dropping tokens.
220
+ sl, bs, _ = x.size()
221
+ expert_capacity = self.expert_capacity(sl * bs)
222
+ if expert_capacity == 0:
223
+ expert_capacity = torch.max(tokens_per_expert).item()
224
+
225
+ x = self.permute_and_compute(
226
+ x,
227
+ tokens_per_expert,
228
+ indices,
229
+ bin_ids,
230
+ expert_weights,
231
+ bins,
232
+ expert_capacity,
233
+ self.top_k,
234
+ )
235
+ return x, tokens_per_expert
236
+
237
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
238
+ # NOTE: This function implements the same computation as forward_once
239
+ # but with expert model parallelism.
240
+ #
241
+ # 1. Permute the tokens locally so that they are grouped by their
242
+ # expert assignments. This allows us to transfer all of the tokens
243
+ # for a remote device in one communication primitive.
244
+ #
245
+ # 2. Permute the tokens across the expert parallel devices. After
246
+ # this is completed each device has all of the tokens assigned to
247
+ # its set of experts in its local HBM.
248
+ #
249
+ # 3. Permute the tokens locally so that they are grouped by their
250
+ # expert assignement. After the distributed permutation the tokens
251
+ # are grouped by which device they came from. We re-order them
252
+ # locally to allow for efficient computation.
253
+ #
254
+ # After this series of permutations we compute the linear layers
255
+ # and then repeat these three steps in reverse to produce the final
256
+ # output.
257
+ #
258
+ # Compute the mapping of local tokens to experts.
259
+ expert_weights = expert_weights.flatten()
260
+ top_experts = top_experts.flatten()
261
+ with torch.no_grad():
262
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
263
+
264
+ # If we're sharding the experts along the hidden dimension
265
+ # multiple devices own parts of the same sets of experts.
266
+ # Replicate the token counts so every device gets the counts.
267
+ repeated_tokens_per_expert = ops.repeat(
268
+ tokens_per_expert,
269
+ (mpu.hidden_sharding_degree(self.args),),
270
+ )
271
+
272
+ # Pass token count information to the device on which the
273
+ # target expert resides.
274
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
275
+ tpe_handle = dist.all_to_all_single(
276
+ parallel_tokens_per_expert,
277
+ repeated_tokens_per_expert,
278
+ group=self.args.expert_parallel_group,
279
+ async_op=True,
280
+ )
281
+
282
+ # Permute locally and without any padding so that tokens for each
283
+ # parallel device are stored contiguously.
284
+ #
285
+ # This view updates the shape of the tensor from [sl, bs, hs] to
286
+ # [sl * bs, hs] prior to the permutation.
287
+ x = x.view(-1, x.shape[-1])
288
+ output = ops.gather(x, indices, bin_ids, bins, self.top_k)
289
+ assert output is not None
290
+ x = output
291
+
292
+ # Compute the number of tokens that will be received from each
293
+ # device and permute the input data across the devices.
294
+ with torch.no_grad():
295
+ tpe_handle.wait()
296
+ experts_per_rank = mpu.experts_per_rank(self.args)
297
+
298
+ # Reshape to [world_size, num_experts_per_rank].
299
+ world_size = mpu.get_expert_parallel_world_size(self.args)
300
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
301
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
302
+
303
+ # TODO(tgale): It might be faster to do this on the GPU and
304
+ # then communicate the results back to the host.
305
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
306
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
307
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
308
+
309
+ # Convert the send/recv counts to lists.
310
+ send_counts = send_counts.tolist()
311
+ recv_counts = recv_counts.tolist()
312
+ tokens_received = sum(recv_counts)
313
+
314
+ # If we're sharding the experts along the hidden dimension
315
+ # multiple devices own parts of the same sets of experts.
316
+ # Replicate the token counts so devices that share experts
317
+ # get all of the tokens assigned to them.
318
+ #
319
+ # TODO(tgale): Fuse this into the prior, local permutation.
320
+ x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
321
+
322
+ # Start the cross-device permutation asynchronously so we can
323
+ # overlap communication with computation.
324
+ parallel_x, parallel_x_handle = all_to_all(
325
+ x,
326
+ recv_counts,
327
+ send_counts,
328
+ self.args.expert_parallel_group,
329
+ async_op=True,
330
+ )
331
+
332
+ with torch.no_grad():
333
+ # After we do the cross-device permutation we have the tokens on the
334
+ # correct device but not yet grouped by expert because we received
335
+ # tokens from each device as contiguous chunks. To group the tokens
336
+ # for expert computation we'll do one more local permutation. The
337
+ # rest of this torch.no_grad() scope sets up the indices and bins
338
+ # for this permutation.
339
+ replicate_bins = ops.inclusive_cumsum(
340
+ parallel_tokens_per_expert.flatten(),
341
+ 0,
342
+ )
343
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
344
+
345
+ # Construct the expert indices for the permuted tokens.
346
+ parallel_top_expert = torch.remainder(
347
+ torch.arange(
348
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
349
+ dtype=torch.int32,
350
+ device=indices.device,
351
+ ),
352
+ mpu.experts_per_rank(self.args),
353
+ )
354
+ parallel_top_expert = ops.replicate(
355
+ parallel_top_expert.unsqueeze(dim=0),
356
+ replicate_bins,
357
+ tokens_received,
358
+ ).flatten()
359
+
360
+ # TODO(tgale): The sort_end_bit here can be reduced.
361
+ parallel_bin_ids, parallel_indices = ops.sort(
362
+ parallel_top_expert,
363
+ self.sort_end_bit,
364
+ )
365
+
366
+ # Calculate the bins boundaries from the token counts.
367
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
368
+ dim=0,
369
+ dtype=torch.int,
370
+ )
371
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
372
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
373
+
374
+ # If expert_capacity is set to zero, set the number of tokens
375
+ # per expert to the maximum we need to avoid dropping tokens.
376
+ tokens, _ = x.size()
377
+ expert_capacity = self.expert_capacity(tokens)
378
+ if expert_capacity == 0:
379
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
380
+
381
+ # Locally permute the tokens and perform the expert computation.
382
+ # Block to make sure that the cross-device permutation is complete.
383
+ if self.args.mlp_impl == 'grouped':
384
+ # GroupedMLP requires counts on CPU. We can use the tensor already
385
+ # moved to CPU for the prior all_to_all, which avoids an extra
386
+ # device synchronization.
387
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
388
+ dim=0,
389
+ dtype=torch.int,
390
+ )
391
+ parallel_x_handle.wait()
392
+ parallel_x = self.permute_and_compute(
393
+ parallel_x,
394
+ parallel_tokens_per_expert,
395
+ parallel_indices,
396
+ parallel_bin_ids,
397
+ None, # expert_weights
398
+ parallel_bins,
399
+ expert_capacity,
400
+ top_k=1,
401
+ )
402
+
403
+ # Un-permute the tokens across the devices.
404
+ x, _ = all_to_all(
405
+ parallel_x,
406
+ send_counts,
407
+ recv_counts,
408
+ self.args.expert_parallel_group,
409
+ )
410
+
411
+ # Reduce along the hidden sharding to get the final outputs.
412
+ #
413
+ # TODO(tgale): Fuse this into the following local permutation.
414
+ shape = (
415
+ mpu.hidden_sharding_degree(self.args),
416
+ -1,
417
+ self.args.hidden_size,
418
+ )
419
+ x = ops.sum(x.view(shape), dim=0)
420
+
421
+ # Un-permute locally to setup for the next series of operations.
422
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
423
+ return x, tokens_per_expert.flatten()
424
+
425
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
426
+ in_shape = x.size()
427
+
428
+ # Compute the experts.
429
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
430
+ if self.training and self.args.moe_loss_weight > 0:
431
+ save_load_balancing_loss((tokens_per_expert, scores))
432
+ x = x.view(in_shape)
433
+ if self.bias is not None:
434
+ if self.args.return_bias:
435
+ return x, self.bias
436
+ return x + self.bias
437
+ return x
438
+
439
+
440
+ class MoE(torch.nn.Module):
441
+
442
+ def __init__(self, args: Arguments):
443
+ super(MoE, self).__init__()
444
+
445
+ # Token router.
446
+ self.router = router.LearnedRouter(args)
447
+
448
+ # Expert computation helper.
449
+ self.experts = self._init_experts_mlp(args)
450
+
451
+ self.shared_expert = None
452
+ if args.shared_expert:
453
+ # SharedExpert computation helper.
454
+ self.shared_expert = sharedexpert_registry.get(args)
455
+
456
+ def _init_experts_mlp(self, args: Arguments):
457
+ return ParallelMLP(args)
458
+
459
+ def forward(self, x: torch.Tensor):
460
+ # NOTE: If we're going to cast the activations to lower precision
461
+ # do it before we permute the tokens to save bandwidth.
462
+ x = common.cast_if_autocast_enabled(x)
463
+
464
+ # Compute the expert scores and assignments.
465
+ scores, expert_weights, top_experts = self.router(x)
466
+
467
+ # Compute the experts.
468
+ out = self.experts(x, scores, expert_weights, top_experts)
469
+ if self.shared_expert is not None:
470
+ shared_expert_out = self.shared_expert(x)
471
+ out = self.shared_expert.add_experts_sharedexpert(
472
+ shared_expert_out,
473
+ out,
474
+ )
475
+ return out
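For intuition, the per-layer load balancing loss computed above reduces to `scale * dot(tokens_per_expert, expert_scores.mean(dim=0))` with `scale = num_experts / (tokens * top_k)`. The toy numbers below (chosen arbitrarily, not from the source) show that a perfectly balanced assignment with uniform scores gives a loss of 1.0.

import torch

num_experts, tokens, top_k = 4, 8, 2
tokens_per_expert = torch.tensor([4., 4., 4., 4.])                  # balanced: 8 tokens * top-2 routing
expert_scores = torch.full((tokens, num_experts), 1.0 / num_experts)
scale = num_experts / (tokens * top_k)
loss = scale * torch.dot(tokens_per_expert, expert_scores.mean(dim=0))
assert abs(loss.item() - 1.0) < 1e-6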
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py ADDED
@@ -0,0 +1,93 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+ from megablocks.layers.arguments import Arguments
10
+
11
+
12
+ class MoeParam(torch.Tensor):
13
+
14
+ def __init__(self):
15
+ super().__init__(self)
16
+ self.expert_model_parallel: bool
17
+
18
+
19
+ def is_moe_param(tensor: torch.Tensor) -> bool:
20
+ return hasattr(tensor, 'expert_model_parallel')
21
+
22
+
23
+ def get_expert_parallel_world_size(args: Arguments) -> int:
24
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
25
+
26
+
27
+ def get_expert_parallel_rank(args: Arguments) -> int:
28
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
29
+
30
+
31
+ def set_expert_model_parallel_attributes(
32
+ tensor: torch.Tensor,
33
+ is_parallel: bool,
34
+ ):
35
+ assert not hasattr(tensor, 'expert_model_parallel')
36
+ setattr(tensor, 'expert_model_parallel', is_parallel)
37
+
38
+
39
+ def param_is_expert_model_parallel(param: MoeParam) -> bool:
40
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
41
+
42
+
43
+ def copy_expert_model_parallel_attributes(
44
+ destination_tensor: torch.Tensor,
45
+ source_tensor: torch.Tensor,
46
+ ):
47
+ if hasattr(source_tensor, 'expert_model_parallel'):
48
+ setattr(
49
+ destination_tensor,
50
+ 'expert_model_parallel',
51
+ getattr(source_tensor, 'expert_model_parallel'),
52
+ )
53
+
54
+
55
+ def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
56
+ world_size = dist.get_world_size(group)
57
+ rank = dist.get_rank(group)
58
+ for i in range(world_size):
59
+ dist.barrier(group)
60
+ if i == rank:
61
+ print(f'rank = {rank}', *x)
62
+
63
+
64
+ # Helpers for expert/tensor sharding.
65
+ def expert_sharding_degree(args: Arguments) -> int:
66
+ world_size = get_expert_parallel_world_size(args)
67
+ esd = min(world_size, args.moe_num_experts)
68
+
69
+ if (args.moe_num_experts % esd) != 0:
70
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
71
+ return esd
72
+
73
+
74
+ def hidden_sharding_degree(args: Arguments) -> int:
75
+ world_size = get_expert_parallel_world_size(args)
76
+ esd = expert_sharding_degree(args)
77
+ hsd = world_size // esd
78
+
79
+ if (args.ffn_hidden_size % hsd) != 0:
80
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
81
+ if (esd * hsd) != world_size:
82
+ raise ValueError(
83
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
84
+ )
85
+ return hsd
86
+
87
+
88
+ def experts_per_rank(args: Arguments) -> int:
89
+ return args.moe_num_experts // expert_sharding_degree(args)
90
+
91
+
92
+ def features_per_rank(args: Arguments) -> int:
93
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
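A quick worked example of how the helpers above factor the expert-parallel world size (plain integers standing in for the Arguments fields): with 16 ranks and 8 experts, the experts are sharded 8 ways and the leftover factor of 2 shards the ffn hidden dimension.

world_size, moe_num_experts, ffn_hidden_size = 16, 8, 8192
expert_sharding_degree = min(world_size, moe_num_experts)       # 8
hidden_sharding_degree = world_size // expert_sharding_degree   # 2
experts_per_rank = moe_num_experts // expert_sharding_degree    # 1 expert per rank
features_per_rank = ffn_hidden_size // hidden_sharding_degree   # 4096 ffn rows per rank
assert expert_sharding_degree * hidden_sharding_degree == world_size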
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+
7
+ from megablocks.layers import common
8
+ from megablocks.layers.arguments import Arguments
9
+
10
+ _ROUTER_LOGITS = []
11
+
12
+
13
+ def _save_router_logits(logits: torch.Tensor, args: Arguments):
14
+ if args.moe_zloss_weight == 0:
15
+ return
16
+ global _ROUTER_LOGITS
17
+ _ROUTER_LOGITS.append(logits)
18
+
19
+
20
+ def clear_router_zloss():
21
+ global _ROUTER_LOGITS
22
+ _ROUTER_LOGITS.clear()
23
+
24
+
25
+ def batched_router_zloss(args: Arguments):
26
+ global _ROUTER_LOGITS
27
+
28
+ if args.moe_zloss_weight == 0:
29
+ import warnings
30
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
31
+ return 0
32
+
33
+ logits_per_router = _ROUTER_LOGITS
34
+
35
+ if args.moe_zloss_in_fp32:
36
+ logits_per_router = [logits.float() for logits in logits_per_router]
37
+
38
+ unscaled_zloss_per_router = torch.stack([
39
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
40
+ ])
41
+
42
+ return args.moe_zloss_weight * unscaled_zloss_per_router
43
+
44
+
45
+ # NOTE: To enable end-to-end benchmarking without convergence we
46
+ # support a flag to force the router to assign tokens uniformly
47
+ # across the experts. We do this with a custom autograd operation
48
+ # so that PyTorch still executes the full set of router operations.
49
+ class _UniformExpertAssignment(torch.autograd.Function):
50
+
51
+ @staticmethod
52
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
53
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
54
+ out = torch.remainder(out, num_experts)
55
+ return out.view(x.shape)
56
+
57
+
58
+ _uniform_expert_assignment = _UniformExpertAssignment.apply
59
+
60
+
61
+ class LearnedRouter(torch.nn.Module):
62
+
63
+ def __init__(self, args: Arguments):
64
+ super().__init__()
65
+ self.args = args
66
+
67
+ # Learned router parameters.
68
+ #
69
+ # NOTE: This weight matrix is not parallelized with expert model
70
+ # parallelism. Each device needs the entire router weight matrix
71
+ # so that it can route its batch of data correctly.
72
+ self.layer = torch.nn.Linear(
73
+ args.hidden_size,
74
+ args.moe_num_experts,
75
+ bias=False,
76
+ dtype=common.dtype(args),
77
+ device=args.device,
78
+ )
79
+ args.init_method(self.layer.weight)
80
+
81
+ def jitter(self, x: torch.Tensor):
82
+ low: float = 1.0 - self.args.moe_jitter_eps
83
+ high: float = 1.0 + self.args.moe_jitter_eps
84
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
85
+ return low + noise * (high - low)
86
+
87
+ def _top_k(self, scores: torch.Tensor):
88
+ if self.args.moe_top_k == 1:
89
+ return scores.max(dim=-1, keepdim=True)
90
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
91
+
92
+ def forward(self, x: torch.Tensor):
93
+ if self.training and self.args.moe_jitter_eps is not None:
94
+ x = x * self.jitter(x)
95
+
96
+ logits = self.layer(x.view(-1, x.shape[-1]))
97
+ _save_router_logits(logits, self.args)
98
+ scores = logits.softmax(dim=-1)
99
+ expert_weights, expert_indices = self._top_k(scores)
100
+ if self.args.moe_normalize_expert_weights:
101
+ expert_weights = expert_weights / torch.norm(
102
+ expert_weights,
103
+ p=self.args.moe_normalize_expert_weights,
104
+ dim=-1,
105
+ keepdim=True,
106
+ )
107
+
108
+ expert_indices = (
109
+ _uniform_expert_assignment(
110
+ expert_indices,
111
+ self.args.moe_num_experts,
112
+ ) if self.args.uniform_expert_assignment else expert_indices
113
+ )
114
+ return scores, expert_weights, expert_indices
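Stripped of the Arguments plumbing, the router above is a linear layer followed by a softmax and a top-k. The standalone sketch below mirrors those steps with plain tensors; the normalization is shown for the p=1 case, which reduces to dividing by the sum of the selected weights.

import torch

tokens, hidden_size, num_experts, top_k = 6, 8, 4, 2
x = torch.randn(tokens, hidden_size)
router_weight = torch.randn(num_experts, hidden_size)

logits = x @ router_weight.t()                          # [tokens, num_experts]
scores = logits.softmax(dim=-1)
expert_weights, expert_indices = torch.topk(scores, top_k, dim=-1)
expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)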
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Union
5
+
6
+ from megablocks.layers import glu, mlp
7
+ from megablocks.layers.arguments import Arguments
8
+
9
+ _REGISTRY = {
10
+ 'mlp': mlp.SharedMLP,
11
+ 'glu': glu.SharedGLU,
12
+ }
13
+
14
+
15
+ def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
16
+ """Returns an SharedMLP for use in a dMoE instance.
17
+
18
+ Uses the provided arguments to instantiate the appropriate
19
+ SharedMLP instance.
20
+
21
+ Args:
22
+ args: propagated Arguments dataclass.
23
+
24
+ Returns:
25
+ An instantiated SharedMLP constructed using the input args.
26
+ """
27
+ if args.mlp_type not in _REGISTRY:
28
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
29
+
30
+ return _REGISTRY[args.mlp_type](args)
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from megablocks.ops.binned_gather import binned_gather
5
+ from megablocks.ops.binned_scatter import binned_scatter
6
+ from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum
7
+ from megablocks.ops.gather import gather
8
+ from megablocks.ops.histogram import histogram
9
+ from megablocks.ops.padded_gather import padded_gather
10
+ from megablocks.ops.padded_scatter import padded_scatter
11
+ from megablocks.ops.repeat import repeat
12
+ from megablocks.ops.replicate import replicate
13
+ from megablocks.ops.round_up import round_up
14
+ from megablocks.ops.scatter import scatter
15
+ from megablocks.ops.sort import sort
16
+ from megablocks.ops.sum import sum
17
+ from megablocks.ops.topology import topology
18
+
19
+ __all__ = [
20
+ 'binned_gather',
21
+ 'binned_scatter',
22
+ 'exclusive_cumsum',
23
+ 'inclusive_cumsum',
24
+ 'gather',
25
+ 'histogram',
26
+ 'padded_gather',
27
+ 'padded_scatter',
28
+ 'repeat',
29
+ 'replicate',
30
+ 'round_up',
31
+ 'scatter',
32
+ 'sort',
33
+ 'sum',
34
+ 'topology',
35
+ ]
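The ops exported here are fused CUDA kernels, but the sort / histogram / inclusive_cumsum trio that the MoE layers build on has a straightforward plain-torch analogue, sketched below for intuition (the real ops differ in dtype handling and performance, not in the bucketing idea).

import torch

num_experts = 4
top_expert = torch.tensor([2, 0, 3, 2, 1, 0])
bin_ids, indices = torch.sort(top_expert)                      # expert id per sorted token, permutation
tokens_per_expert = torch.bincount(top_expert, minlength=num_experts)
bins = torch.cumsum(tokens_per_expert, dim=0)                  # inclusive bin boundaries
# bins[e] is the end offset of expert e's tokens in the sorted order.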
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py ADDED
@@ -0,0 +1,60 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+
7
+ from megablocks import benchmark_util
8
+ from megablocks.layers.all_to_all import all_to_all
9
+
10
+ _ALL_TO_ALL_BENCHMARK = (
11
+ (8, 1024),
12
+ (16, 1024),
13
+ (32, 1024),
14
+ (64, 1024),
15
+ (128, 1024),
16
+ (256, 1024),
17
+ (512, 1024),
18
+ (1024, 1024),
19
+ (2 * 1024, 1024),
20
+ (4 * 1024, 1024),
21
+ (8 * 1024, 1024),
22
+ (16 * 1024, 1024),
23
+ (32 * 1024, 1024),
24
+ (64 * 1024, 1024),
25
+ (128 * 1024, 1024),
26
+ (256 * 1024, 1024),
27
+ (512 * 1024, 1024),
28
+ (1024 * 1024, 1024),
29
+ )
30
+
31
+
32
+ def benchmark_all_to_all(group, sl, hs):
33
+ world_size = dist.get_world_size(group)
34
+ assert (sl % world_size) == 0
35
+ send_recv_sizes = [sl // world_size] * world_size
36
+
37
+ x = torch.randn((sl, hs)).cuda().half()
38
+
39
+ details = {
40
+ 'world_size': world_size,
41
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements.
42
+ }
43
+
44
+ def benchmark():
45
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
46
+
47
+ time, std = benchmark_util.benchmark_function(benchmark)
48
+
49
+ if dist.get_rank(group) == 0:
50
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
51
+
52
+
53
+ if __name__ == '__main__':
54
+ assert dist.is_available()
55
+ group = dist.init_process_group(backend='nccl')
56
+ local_rank = dist.get_rank(group)
57
+ torch.cuda.set_device(local_rank)
58
+
59
+ for args in _ALL_TO_ALL_BENCHMARK:
60
+ benchmark_all_to_all(group, *args)
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+ from stk.backend.autocast import custom_bwd, custom_fwd
7
+
8
+ from megablocks.backend import kernels
9
+
10
+
11
+ # Autograd wrapper for binned_gather kernel.
12
+ class BinnedGatherOp(torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ @custom_fwd
16
+ def forward(
17
+ ctx: Any,
18
+ x: torch.Tensor,
19
+ indices: torch.Tensor,
20
+ bins: torch.Tensor,
21
+ bin_size: int,
22
+ top_k: int,
23
+ ):
24
+ ctx.save_for_backward(indices, bins)
25
+ ctx.top_k = top_k
26
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
27
+
28
+ @staticmethod
29
+ @custom_bwd
30
+ def backward(ctx: Any, grad: torch.Tensor):
31
+ grad = grad.contiguous()
32
+ indices, bins = ctx.saved_tensors
33
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
34
+ return out, None, None, None, None
35
+
36
+
37
+ binned_gather = BinnedGatherOp.apply
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+ from stk.backend.autocast import custom_bwd, custom_fwd
7
+
8
+ from megablocks.backend import kernels
9
+
10
+
11
+ # Autograd wrapper for binned_scatter kernel.
12
+ class BinnedScatterOp(torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ @custom_fwd
16
+ def forward(
17
+ ctx: Any,
18
+ x: torch.Tensor,
19
+ indices: torch.Tensor,
20
+ weights: torch.Tensor,
21
+ bins: torch.Tensor,
22
+ top_k: int,
23
+ ):
24
+ assert len(x.size()) == 3
25
+ ctx.bin_size = x.size(1)
26
+ ctx.top_k = top_k
27
+
28
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
29
+ # calculate the gradient w.r.t. 'weights'.
30
+ ctx.save_for_backward(x, indices, weights, bins)
31
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
32
+
33
+ @staticmethod
34
+ @custom_bwd
35
+ def backward(ctx: Any, grad: torch.Tensor):
36
+ grad = grad.contiguous()
37
+ x, indices, weights, bins = ctx.saved_tensors
38
+ out = kernels.binned_gather(
39
+ grad,
40
+ indices,
41
+ weights,
42
+ bins,
43
+ ctx.bin_size,
44
+ ctx.top_k,
45
+ )
46
+
47
+ wgrad = None
48
+ if ctx.needs_input_grad[2]:
49
+ wgrad = kernels.binned_scatter_wgrad(
50
+ x,
51
+ grad,
52
+ indices,
53
+ bins,
54
+ ctx.top_k,
55
+ )
56
+ return out, None, wgrad, None, None
57
+
58
+
59
+ binned_scatter = BinnedScatterOp.apply
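The gather/scatter wrappers in this directory all follow the same torch.autograd.Function recipe: stash tensors for the backward pass with ctx.save_for_backward, keep plain scalars (bin_size, top_k) as attributes on ctx, and consult ctx.needs_input_grad to skip gradients nobody asked for. A minimal, CPU-runnable sketch of that recipe, using a toy elementwise op rather than a megablocks kernel:

    import torch

    class ScaledCopy(torch.autograd.Function):
        # Toy op y = x * w, kept only to illustrate the ctx plumbing.

        @staticmethod
        def forward(ctx, x: torch.Tensor, w: torch.Tensor):
            ctx.save_for_backward(x, w)
            return x * w

        @staticmethod
        def backward(ctx, grad: torch.Tensor):
            x, w = ctx.saved_tensors
            # Only materialize gradients that are actually needed.
            dx = grad * w if ctx.needs_input_grad[0] else None
            dw = grad * x if ctx.needs_input_grad[1] else None
            return dx, dw

    x = torch.randn(4, requires_grad=True)
    w = torch.randn(4, requires_grad=True)
    ScaledCopy.apply(x, w).sum().backward()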
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ # NOTE: Torch needs to be imported before the custom
7
+ # extensions. Otherwise libc10.so cannot be found.
8
+ import torch
9
+
10
+ # Wrap this in a try-block with better error message and
11
+ # instructions for building the c++ operations.
12
+ try:
13
+ # import megablocks_ops as ops # type: ignore
14
+ from megablocks._ops import ops # type: ignore
15
+ except ModuleNotFoundError as e:
16
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
17
+
18
+
19
+ # Autograd wrappers for cumsum kernels.
20
+ # NOTE: Does not support gradients.
21
+ class ExclusiveCumsumOp(torch.autograd.Function):
22
+
23
+ @staticmethod
24
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
25
+ if len(x.size()) == 1:
26
+ x = x.view([1, -1])
27
+ out = torch.empty_like(x)
28
+ ops.exclusive_cumsum(x, 1, out)
29
+ return out.squeeze()
30
+ out = torch.empty_like(x)
31
+ ops.exclusive_cumsum(x, dim, out)
32
+ return out
33
+
34
+
35
+ exclusive_cumsum = ExclusiveCumsumOp.apply
36
+
37
+
38
+ class InclusiveCumsumOp(torch.autograd.Function):
39
+
40
+ @staticmethod
41
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
42
+ if len(x.size()) == 1:
43
+ x = x.view([1, -1])
44
+ out = torch.empty_like(x)
45
+ ops.inclusive_cumsum(x, 1, out)
46
+ return out.squeeze()
47
+ out = torch.empty_like(x)
48
+ ops.inclusive_cumsum(x, dim, out)
49
+ return out
50
+
51
+
52
+ inclusive_cumsum = InclusiveCumsumOp.apply
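As a CPU reference for the semantics of these wrappers (not for the kernel itself): inclusive_cumsum matches torch.cumsum, and exclusive_cumsum is the same prefix sum shifted right with a leading zero. A small sketch:

    import torch

    x = torch.tensor([3, 1, 4, 1, 5])

    inclusive = torch.cumsum(x, dim=0)   # tensor([ 3,  4,  8,  9, 14])
    exclusive = inclusive - x            # tensor([0, 3, 4, 8, 9])

    # Equivalently: prepend a zero and drop the last inclusive entry.
    assert torch.equal(exclusive, torch.cat([x.new_zeros(1), inclusive[:-1]]))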
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+ from stk.backend.autocast import custom_bwd, custom_fwd
7
+
8
+ from megablocks.backend import kernels
9
+
10
+
11
+ # Autograd wrapper for gather kernel.
12
+ class GatherOp(torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ @custom_fwd
16
+ def forward(
17
+ ctx: Any,
18
+ x: torch.Tensor,
19
+ indices: torch.Tensor,
20
+ bin_ids: torch.Tensor,
21
+ bins: torch.Tensor,
22
+ top_k: int,
23
+ ):
24
+ ctx.save_for_backward(indices, bin_ids, bins)
25
+ ctx.top_k = top_k
26
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
27
+
28
+ @staticmethod
29
+ @custom_bwd
30
+ def backward(ctx: Any, grad: torch.Tensor):
31
+ grad = grad.contiguous()
32
+
33
+ indices, bin_ids, bins = ctx.saved_tensors
34
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
35
+ return out, None, None, None, None, None
36
+
37
+
38
+ gather = GatherOp.apply
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ # NOTE: Torch needs to be imported before the custom
7
+ # extensions. Otherwise libc10.so cannot be found.
8
+ import torch
9
+
10
+ # Wrap this in a try-block with better error message and
11
+ # instructions for building the c++ operations.
12
+ try:
13
+ from megablocks._ops import ops # type: ignore
14
+ except ModuleNotFoundError as e:
15
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
16
+
17
+
18
+ # Autograd wrapper for histogram kernel.
19
+ # NOTE: Does not support gradients.
20
+ class HistogramOp(torch.autograd.Function):
21
+
22
+ @staticmethod
23
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
24
+ return ops.histogram(x, max_val)
25
+
26
+
27
+ histogram = HistogramOp.apply
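For non-negative integer inputs in [0, max_val), the result of ops.histogram can be reproduced on CPU with torch.bincount (or torch.histc, which is what the benchmark below uses as its baseline); the custom kernel exists for speed, not for different semantics. A small reference sketch:

    import torch

    max_val = 4
    x = torch.tensor([0, 2, 2, 3, 1, 2])

    counts = torch.bincount(x, minlength=max_val)   # tensor([1, 1, 3, 1])
    ref = torch.histc(x.float(), bins=max_val, min=0, max=max_val - 1)
    assert torch.equal(counts, ref.long())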
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import unittest
5
+
6
+ import numpy as np
7
+ import torch
8
+ from absl.testing import parameterized
9
+
10
+ from megablocks import ops
11
+
12
+ _HISTOGRAM_TESTS = (
13
+ (16384, torch.int32, 2),
14
+ (16384, torch.int32, 4),
15
+ (16384, torch.int32, 8),
16
+ (16384, torch.int32, 16),
17
+ (16384, torch.int32, 32),
18
+ (16384, torch.int32, 64),
19
+ (16384, torch.int32, 128),
20
+ (16384, torch.int32, 256),
21
+ )
22
+
23
+
24
+ def benchmark_function(fn, iterations=10):
25
+ # Run once to get rid of startup overhead.
26
+ fn()
27
+ times = []
28
+ for _ in range(iterations):
29
+ start = torch.cuda.Event(enable_timing=True)
30
+ end = torch.cuda.Event(enable_timing=True)
31
+ start.record()
32
+ fn()
33
+ end.record()
34
+ torch.cuda.synchronize()
35
+ times.append(start.elapsed_time(end))
36
+ times = np.array(times)
37
+ return times.mean(), times.std(), times.max(), times.min()
38
+
39
+
40
+ def log_benchmark(arguments, mean_t, std_t):
41
+ print('=' * 60)
42
+ print('Benchmark Parameters:')
43
+ for (key, value) in arguments.items():
44
+ print(f'{key} = {value}')
45
+ print('Results:')
46
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
47
+ print('=' * 60)
48
+
49
+
50
+ class HistogramBenchmark(parameterized.TestCase):
51
+
52
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
53
+ def testHistogram(self, n, dtype, max_val):
54
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
55
+
56
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
57
+ arguments = {
58
+ 'n': n,
59
+ 'dtype': dtype,
60
+ 'max_val': max_val,
61
+ }
62
+ log_benchmark(arguments, mean_t, std_t)
63
+
64
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
65
+ def testTorchHistogram(self, n, dtype, max_val):
66
+ x = torch.randint(0, 128, (n,)).cuda().to(dtype)
67
+
68
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
69
+ arguments = {
70
+ 'n': n,
71
+ 'dtype': dtype,
72
+ 'max_val': max_val,
73
+ }
74
+ log_benchmark(arguments, mean_t, std_t)
75
+
76
+
77
+ if __name__ == '__main__':
78
+ unittest.main()
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py ADDED
@@ -0,0 +1,403 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import unittest
5
+
6
+ import stk
7
+ import torch
8
+ from absl.testing import parameterized
9
+
10
+ from megablocks import benchmark_util, ops
11
+
12
+
13
+ # Calling tensor.t() calls tensor.transpose(0, 1) which calls
14
+ # torch.as_strided(...). Circumvent this chain to avoid an overhead
15
+ # this adds.
16
+ def transpose_view(x):
17
+ return torch.as_strided(
18
+ x,
19
+ (x.shape[1], x.shape[0]),
20
+ (x.stride()[1], x.stride()[0]),
21
+ )
22
+
23
+
24
+ _MATMUL_TESTS = (
25
+ (64 * 1024, 512, 2048, 64),
26
+ (32 * 1024, 768, 3072, 64),
27
+ (8 * 1024, 1024, 4096, 64),
28
+ (4 * 2048, 4096, 4 * 4096, 4),
29
+ )
30
+
31
+
32
+ def log_benchmark(name, arguments, time, std, flops):
33
+ benchmark_util.log_benchmark(name, arguments, time, std)
34
+ print('flops = {:.2f}B'.format(flops / 1e9))
35
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
36
+ print('=' * 60)
37
+
38
+
39
+ class MatmulBenchmark(parameterized.TestCase):
40
+
41
+ def build_sparse_matrix(self, x, padded_bins, fhs, ne):
42
+ blocking = 128
43
+ padded_tokens, _ = x.size()
44
+ assert padded_tokens % blocking == 0
45
+ assert fhs % blocking == 0
46
+
47
+ # Offsets for the sparse matrix. All rows have the
48
+ # same number of nonzero blocks dictated by the
49
+ # dimensionality of a single expert.
50
+ block_rows = padded_tokens // blocking
51
+ blocks_per_row = fhs // blocking
52
+ offsets = torch.arange(
53
+ 0,
54
+ block_rows * blocks_per_row + 1,
55
+ blocks_per_row,
56
+ dtype=torch.int32,
57
+ device=x.device,
58
+ )
59
+
60
+ # Indices for the sparse matrix. The indices for
61
+ # the intermediate matrix are dynamic depending
62
+ # on the mapping of tokens to experts.
63
+ column_indices = ops.topology(
64
+ padded_bins,
65
+ blocking,
66
+ block_rows,
67
+ blocks_per_row,
68
+ )
69
+ data = torch.empty(
70
+ column_indices.numel(),
71
+ blocking,
72
+ blocking,
73
+ dtype=torch.float16,
74
+ device=x.device,
75
+ )
76
+ shape = (padded_tokens, fhs * ne)
77
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
78
+ return stk.Matrix(shape, data, row_indices, column_indices, offsets)
79
+
80
+ def build_input_matrix(self, sl, hs, ne):
81
+ x = torch.randn((sl, hs)).cuda().half()
82
+
83
+ # Assign tokens to experts uniformly.
84
+ top_expert = torch.arange(0, sl).cuda().int() % ne
85
+
86
+ bin_ids, indices = ops.sort(top_expert)
87
+ tokens_per_expert = ops.histogram(top_expert, ne)
88
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
89
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
90
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
91
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
92
+ return out, padded_bins
93
+
94
+ def build_weight_matrix(self, ne, hs, fhs):
95
+ return torch.randn((hs, ne * fhs)).cuda().half()
96
+
97
+ @parameterized.parameters(*_MATMUL_TESTS)
98
+ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
99
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
100
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
101
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
102
+ w = transpose_view(w)
103
+
104
+ def benchmark():
105
+ return stk.ops.sdd(x, w, topo)
106
+
107
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
108
+ arguments = {
109
+ 'sequence_length': sl,
110
+ 'hidden_size': hs,
111
+ 'ffn_hidden_size': fhs,
112
+ 'num_experts': ne,
113
+ }
114
+ log_benchmark(
115
+ '0::Fwd::SDD::NT',
116
+ arguments,
117
+ mean_t,
118
+ std_t,
119
+ x.numel() * fhs * 2,
120
+ )
121
+
122
+ @parameterized.parameters(*_MATMUL_TESTS)
123
+ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
124
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
125
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
126
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
127
+
128
+ def benchmark():
129
+ return stk.ops.dsd(topo, w)
130
+
131
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
132
+ arguments = {
133
+ 'sequence_length': sl,
134
+ 'hidden_size': hs,
135
+ 'ffn_hidden_size': fhs,
136
+ 'num_experts': ne,
137
+ }
138
+ log_benchmark(
139
+ '0::GradX::DSD::NN',
140
+ arguments,
141
+ mean_t,
142
+ std_t,
143
+ x.numel() * fhs * 2,
144
+ )
145
+
146
+ @parameterized.parameters(*_MATMUL_TESTS)
147
+ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
148
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
149
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
150
+ topo = topo.t()
151
+
152
+ def benchmark():
153
+ return stk.ops.dsd(topo, x)
154
+
155
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
156
+ arguments = {
157
+ 'sequence_length': sl,
158
+ 'hidden_size': hs,
159
+ 'ffn_hidden_size': fhs,
160
+ 'num_experts': ne,
161
+ }
162
+ log_benchmark(
163
+ '0::GradW::DSD::TN',
164
+ arguments,
165
+ mean_t,
166
+ std_t,
167
+ x.numel() * fhs * 2,
168
+ )
169
+
170
+ @parameterized.parameters(*_MATMUL_TESTS)
171
+ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
172
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
173
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
174
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
175
+
176
+ def benchmark():
177
+ return stk.ops.dsd(x, w)
178
+
179
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
180
+ arguments = {
181
+ 'sequence_length': sl,
182
+ 'hidden_size': hs,
183
+ 'ffn_hidden_size': fhs,
184
+ 'num_experts': ne,
185
+ }
186
+ log_benchmark(
187
+ '1::Fwd::DSD::NN',
188
+ arguments,
189
+ mean_t,
190
+ std_t,
191
+ x.nnz * hs * 2,
192
+ )
193
+
194
+ @parameterized.parameters(*_MATMUL_TESTS)
195
+ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
196
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
197
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
198
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
199
+ out = stk.ops.dsd(x, w)
200
+ w = transpose_view(w)
201
+
202
+ def benchmark():
203
+ return stk.ops.sdd(out, w, x)
204
+
205
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
206
+ arguments = {
207
+ 'sequence_length': sl,
208
+ 'hidden_size': hs,
209
+ 'ffn_hidden_size': fhs,
210
+ 'num_experts': ne,
211
+ }
212
+ log_benchmark(
213
+ '1::GradX::SDD::NT',
214
+ arguments,
215
+ mean_t,
216
+ std_t,
217
+ x.nnz * hs * 2,
218
+ )
219
+
220
+ @parameterized.parameters(*_MATMUL_TESTS)
221
+ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
222
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
223
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
224
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
225
+ out = stk.ops.dsd(x, w)
226
+ x = x.t()
227
+
228
+ def benchmark():
229
+ return stk.ops.dsd(x, out)
230
+
231
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
232
+ arguments = {
233
+ 'sequence_length': sl,
234
+ 'hidden_size': hs,
235
+ 'ffn_hidden_size': fhs,
236
+ 'num_experts': ne,
237
+ }
238
+ log_benchmark(
239
+ '1::GradW::DSD::TN',
240
+ arguments,
241
+ mean_t,
242
+ std_t,
243
+ x.nnz * hs * 2,
244
+ )
245
+
246
+ @parameterized.parameters(*_MATMUL_TESTS)
247
+ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
248
+ assert (sl % ne) == 0
249
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
250
+ w = torch.randn((ne, hs, fhs)).cuda().half()
251
+
252
+ w = w.transpose(1, 2).contiguous()
253
+ w = w.transpose(1, 2)
254
+
255
+ def benchmark():
256
+ return torch.bmm(x, w)
257
+
258
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
259
+ arguments = {
260
+ 'sequence_length': sl,
261
+ 'hidden_size': hs,
262
+ 'ffn_hidden_size': fhs,
263
+ 'num_experts': ne,
264
+ }
265
+ log_benchmark(
266
+ '0::Fwd:DDD::NT',
267
+ arguments,
268
+ mean_t,
269
+ std_t,
270
+ x.numel() * fhs * 2,
271
+ )
272
+
273
+ @parameterized.parameters(*_MATMUL_TESTS)
274
+ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
275
+ assert (sl % ne) == 0
276
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
277
+ w = torch.randn((ne, hs, fhs)).cuda().half()
278
+ out = torch.bmm(x, w)
279
+ w = w.transpose(1, 2).contiguous()
280
+
281
+ def benchmark():
282
+ return torch.bmm(out, w)
283
+
284
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
285
+ arguments = {
286
+ 'sequence_length': sl,
287
+ 'hidden_size': hs,
288
+ 'ffn_hidden_size': fhs,
289
+ 'num_experts': ne,
290
+ }
291
+ log_benchmark(
292
+ '0:GradX:DDD::NN',
293
+ arguments,
294
+ mean_t,
295
+ std_t,
296
+ x.numel() * fhs * 2,
297
+ )
298
+
299
+ @parameterized.parameters(*_MATMUL_TESTS)
300
+ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
301
+ assert (sl % ne) == 0
302
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
303
+ w = torch.randn((ne, hs, fhs)).cuda().half()
304
+ out = torch.bmm(x, w)
305
+ out = out.transpose(1, 2)
306
+
307
+ def benchmark():
308
+ return torch.bmm(out, x)
309
+
310
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
311
+ arguments = {
312
+ 'sequence_length': sl,
313
+ 'hidden_size': hs,
314
+ 'ffn_hidden_size': fhs,
315
+ 'num_experts': ne,
316
+ }
317
+ log_benchmark(
318
+ '0:GradW:DDD::TN',
319
+ arguments,
320
+ mean_t,
321
+ std_t,
322
+ x.numel() * fhs * 2,
323
+ )
324
+
325
+ @parameterized.parameters(*_MATMUL_TESTS)
326
+ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
327
+ assert (sl % ne) == 0
328
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
329
+ w = torch.randn((ne, fhs, hs)).cuda().half()
330
+
331
+ def benchmark():
332
+ return torch.bmm(x, w)
333
+
334
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
335
+ arguments = {
336
+ 'sequence_length': sl,
337
+ 'hidden_size': hs,
338
+ 'ffn_hidden_size': fhs,
339
+ 'num_experts': ne,
340
+ }
341
+ log_benchmark(
342
+ '1::Fwd::DDD::NN',
343
+ arguments,
344
+ mean_t,
345
+ std_t,
346
+ x.numel() * hs * 2,
347
+ )
348
+
349
+ @parameterized.parameters(*_MATMUL_TESTS)
350
+ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
351
+ assert (sl % ne) == 0
352
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
353
+ w = torch.randn((ne, fhs, hs)).cuda().half()
354
+ out = torch.bmm(x, w)
355
+ w = torch.transpose(w, 1, 2)
356
+
357
+ def benchmark():
358
+ return torch.bmm(out, w)
359
+
360
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
361
+ arguments = {
362
+ 'sequence_length': sl,
363
+ 'hidden_size': hs,
364
+ 'ffn_hidden_size': fhs,
365
+ 'num_experts': ne,
366
+ }
367
+ log_benchmark(
368
+ '1::GradX::DDD::NT',
369
+ arguments,
370
+ mean_t,
371
+ std_t,
372
+ x.numel() * hs * 2,
373
+ )
374
+
375
+ @parameterized.parameters(*_MATMUL_TESTS)
376
+ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
377
+ assert (sl % ne) == 0
378
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
379
+ w = torch.randn((ne, fhs, hs)).cuda().half()
380
+ out = torch.bmm(x, w)
381
+ x = torch.transpose(x, 1, 2)
382
+
383
+ def benchmark():
384
+ return torch.bmm(x, out)
385
+
386
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
387
+ arguments = {
388
+ 'sequence_length': sl,
389
+ 'hidden_size': hs,
390
+ 'ffn_hidden_size': fhs,
391
+ 'num_experts': ne,
392
+ }
393
+ log_benchmark(
394
+ '1::GradW::DDD::TN',
395
+ arguments,
396
+ mean_t,
397
+ std_t,
398
+ x.numel() * hs * 2,
399
+ )
400
+
401
+
402
+ if __name__ == '__main__':
403
+ unittest.main()
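The flops value passed to log_benchmark above is the usual 2*M*N*K count for the benchmarked matmul shape (x.numel() * fhs * 2 for the first linear layer, x.nnz * hs * 2 for the sparse intermediate), and throughput divides it by the measured mean time (in milliseconds, as reported by the CUDA-event timers used elsewhere in these benchmarks), which yields TFLOP/s. A quick worked example for the first test case, (sl, hs, fhs, ne) = (64 * 1024, 512, 2048, 64):

    sl, hs, fhs = 64 * 1024, 512, 2048

    flops = sl * hs * fhs * 2                 # x.numel() * fhs * 2
    print(f'flops = {flops / 1e9:.2f}B')      # ~137.44 GFLOP per call

    time_ms = 1.0                             # hypothetical measured mean time
    # GFLOP per millisecond is the same number as TFLOP per second.
    print(f'throughput = {flops / 1e9 / time_ms:.2f}T')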
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+ from stk.backend.autocast import custom_bwd, custom_fwd
7
+
8
+ from megablocks.backend import kernels
9
+
10
+
11
+ # Autograd wrapper for padded_gather kernel.
12
+ class PaddedGatherOp(torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ @custom_fwd
16
+ def forward(
17
+ ctx: Any,
18
+ x: torch.Tensor,
19
+ indices: torch.Tensor,
20
+ bin_ids: torch.Tensor,
21
+ bins: torch.Tensor,
22
+ padded_bins: torch.Tensor,
23
+ top_k: int,
24
+ ):
25
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
26
+ ctx.top_k = top_k
27
+ return kernels.padded_gather(
28
+ x,
29
+ indices,
30
+ bin_ids,
31
+ None,
32
+ bins,
33
+ padded_bins,
34
+ top_k,
35
+ )
36
+
37
+ @staticmethod
38
+ @custom_bwd
39
+ def backward(ctx: Any, grad: torch.Tensor):
40
+ grad = grad.contiguous()
41
+
42
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
43
+ out = kernels.padded_scatter(
44
+ grad,
45
+ indices,
46
+ bin_ids,
47
+ None,
48
+ bins,
49
+ padded_bins,
50
+ ctx.top_k,
51
+ )
52
+ return out, None, None, None, None, None
53
+
54
+
55
+ padded_gather = PaddedGatherOp.apply
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+ from stk.backend.autocast import custom_bwd, custom_fwd
7
+
8
+ from megablocks.backend import kernels
9
+
10
+
11
+ # Autograd wrapper for padded_scatter kernel.
12
+ class PaddedScatterOp(torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ @custom_fwd
16
+ def forward(
17
+ ctx: Any,
18
+ x: torch.Tensor,
19
+ indices: torch.Tensor,
20
+ bin_ids: torch.Tensor,
21
+ weights: torch.Tensor,
22
+ bins: torch.Tensor,
23
+ padded_bins: torch.Tensor,
24
+ top_k: int,
25
+ ):
26
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
27
+ ctx.save_for_backward(
28
+ indices,
29
+ bin_ids,
30
+ weights,
31
+ bins,
32
+ padded_bins,
33
+ *maybe_x,
34
+ )
35
+ ctx.top_k = top_k
36
+ ctx.x_shape = x.shape
37
+ return kernels.padded_scatter(
38
+ x,
39
+ indices,
40
+ bin_ids,
41
+ weights,
42
+ bins,
43
+ padded_bins,
44
+ top_k,
45
+ )
46
+
47
+ @staticmethod
48
+ @custom_bwd
49
+ def backward(ctx: Any, grad: torch.Tensor):
50
+ grad = grad.contiguous()
51
+ saved_tensors = ctx.saved_tensors
52
+
53
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
54
+ dgrad = None
55
+ if ctx.needs_input_grad[0]:
56
+ dgrad = kernels.padded_gather(
57
+ grad,
58
+ indices,
59
+ bin_ids,
60
+ weights,
61
+ bins,
62
+ padded_bins,
63
+ ctx.top_k,
64
+ )
65
+
66
+ wgrad = None
67
+ if ctx.needs_input_grad[3]: # need wgrad
68
+ x = saved_tensors[-1]
69
+ wgrad = kernels.padded_scatter_wgrad(
70
+ x,
71
+ grad,
72
+ indices,
73
+ bin_ids,
74
+ bins,
75
+ padded_bins,
76
+ ctx.top_k,
77
+ )
78
+ return dgrad, None, None, wgrad, None, None, None, None
79
+
80
+
81
+ def padded_scatter(
82
+ x: torch.Tensor,
83
+ indices: torch.Tensor,
84
+ bin_ids: torch.Tensor,
85
+ weights: torch.Tensor,
86
+ bins: torch.Tensor,
87
+ padded_bins: torch.Tensor,
88
+ top_k: int,
89
+ ):
90
+ return PaddedScatterOp.apply(
91
+ x,
92
+ indices,
93
+ bin_ids,
94
+ weights,
95
+ bins,
96
+ padded_bins,
97
+ top_k,
98
+ )
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py ADDED
@@ -0,0 +1,66 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import unittest
5
+
6
+ import torch
7
+ from absl.testing import parameterized
8
+
9
+ from megablocks import benchmark_util, ops
10
+
11
+ _PADDED_SCATTER_BENCHMARK = (
12
+ # dMoE-Medium, 8-way EMP.
13
+ (1024 * 16, 1024, 8, 4),
14
+ # dMoE-Medium, post-all-to-all.
15
+ (1024 * 16 * 4, 1024, 8, 1),
16
+ )
17
+
18
+
19
+ class PaddedScatterTest(parameterized.TestCase):
20
+
21
+ @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
22
+ def testPaddedScatter(self, sl, hs, ne, top_k):
23
+ # Create the data and indices.
24
+ x = torch.randn((sl, hs)).cuda().half()
25
+
26
+ # Randomly assign tokens to experts.
27
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
28
+ bin_ids, indices = ops.sort(top_expert)
29
+ tokens_per_expert = ops.histogram(top_expert, ne)
30
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
31
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
32
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
33
+
34
+ # Sample weights for the scatter reduce.
35
+ weights = torch.rand((sl * top_k,)).cuda().half()
36
+
37
+ # Gather the data to prepare for backwards.
38
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
39
+
40
+ def benchmark():
41
+ return ops.padded_scatter(
42
+ x,
43
+ indices,
44
+ bin_ids,
45
+ weights,
46
+ bins,
47
+ padded_bins,
48
+ top_k,
49
+ )
50
+
51
+ time, std = benchmark_util.benchmark_function(benchmark)
52
+ benchmark_util.log_benchmark(
53
+ 'Padded Scatter',
54
+ {
55
+ 'sequence_length': sl,
56
+ 'hidden_size': hs,
57
+ 'num_experts': ne,
58
+ 'top_k': top_k,
59
+ },
60
+ time,
61
+ std,
62
+ )
63
+
64
+
65
+ if __name__ == '__main__':
66
+ unittest.main()
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import unittest
5
+
6
+ import torch
7
+ from absl.testing import parameterized
8
+
9
+ from megablocks import benchmark_util, ops
10
+
11
+ _PERMUTE_TESTS = (
12
+ (16384, 768, 2),
13
+ (16384, 768, 4),
14
+ (16384, 768, 8),
15
+ (16384, 768, 16),
16
+ (16384, 768, 32),
17
+ (16384, 768, 64),
18
+ (16384, 768, 128),
19
+ (16384 * 8, 768, 2),
20
+ (16384 * 8, 768, 4),
21
+ (16384 * 8, 768, 8),
22
+ (16384 * 8, 768, 16),
23
+ (16384 * 8, 768, 32),
24
+ (16384 * 8, 768, 64),
25
+ (16384 * 8, 768, 128),
26
+ )
27
+
28
+
29
+ class PermuteBenchmark(parameterized.TestCase):
30
+
31
+ @parameterized.parameters(*_PERMUTE_TESTS)
32
+ def testBinnedGather(self, sl, hs, ne):
33
+ # NOTE: Capacity factor == 1.
34
+ ec = sl // ne
35
+
36
+ # Create the data and indices.
37
+ x = torch.randn((sl, hs)).cuda().half()
38
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
39
+ bin_ids, indices = ops.sort(top_expert)
40
+ tokens_per_expert = ops.histogram(indices, ne)
41
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
42
+
43
+ def benchmark():
44
+ return ops.binned_gather(x, indices, bins, ec)
45
+
46
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
47
+ arguments = {
48
+ 'sequence_length': sl,
49
+ 'hidden_size': hs,
50
+ 'num_experts': ne,
51
+ }
52
+ benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
53
+
54
+ @parameterized.parameters(*_PERMUTE_TESTS)
55
+ def testBinnedScatter(self, sl, hs, ne):
56
+ # NOTE: Capacity factor == 1.
57
+ ec = sl // ne
58
+
59
+ # Create the data and indices.
60
+ x = torch.randn((sl, hs)).cuda().half()
61
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
62
+ bin_ids, indices = ops.sort(top_expert)
63
+ tokens_per_expert = ops.histogram(indices, ne)
64
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
65
+ x = ops.binned_gather(x, indices, bins, ec)
66
+
67
+ def benchmark():
68
+ return ops.binned_scatter(x, indices, bins)
69
+
70
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
71
+ arguments = {
72
+ 'sequence_length': sl,
73
+ 'hidden_size': hs,
74
+ 'num_experts': ne,
75
+ }
76
+ benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
77
+
78
+ @parameterized.parameters(*_PERMUTE_TESTS)
79
+ def testPaddedGather(self, sl, hs, ne):
80
+ # Create the data and indices.
81
+ x = torch.randn((sl, hs)).cuda().half()
82
+
83
+ # Randomly assign tokens to experts.
84
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
85
+ bin_ids, indices = ops.sort(top_expert)
86
+ tokens_per_expert = ops.histogram(top_expert, ne)
87
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
88
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
89
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
90
+
91
+ def benchmark():
92
+ return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
93
+
94
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
95
+ arguments = {
96
+ 'sequence_length': sl,
97
+ 'hidden_size': hs,
98
+ 'num_experts': ne,
99
+ }
100
+ benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
101
+
102
+ @parameterized.parameters(*_PERMUTE_TESTS)
103
+ def testPaddedScatter(self, sl, hs, ne):
104
+ # Create the data and indices.
105
+ x = torch.randn((sl, hs)).cuda().half()
106
+
107
+ # Randomly assign tokens to experts.
108
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
109
+ bin_ids, indices = ops.sort(top_expert)
110
+ tokens_per_expert = ops.histogram(top_expert, ne)
111
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
112
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
113
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
114
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
115
+
116
+ def benchmark():
117
+ return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
118
+
119
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
+ arguments = {
121
+ 'sequence_length': sl,
122
+ 'hidden_size': hs,
123
+ 'num_experts': ne,
124
+ }
125
+ benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
126
+
127
+ @parameterized.parameters(*_PERMUTE_TESTS)
128
+ def testCopy(self, sl, hs, ne):
129
+ # NOTE: Capacity factor == 1.
130
+ # ec = sl // ne
131
+
132
+ # Create the data and indices.
133
+ x = torch.randn((sl, hs)).cuda().half()
134
+ y = x.clone()
135
+
136
+ def benchmark():
137
+ return y.copy_(x)
138
+
139
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
140
+ arguments = {
141
+ 'sequence_length': sl,
142
+ 'hidden_size': hs,
143
+ 'num_experts': ne,
144
+ }
145
+ benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
146
+
147
+
148
+ if __name__ == '__main__':
149
+ unittest.main()
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+
6
+
7
+ def repeat(x: torch.Tensor, tiling: torch.Size):
8
+ if all((t == 1 for t in tiling)):
9
+ return x
10
+ return x.repeat(*tiling)
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ # NOTE: Torch needs to be imported before the custom
7
+ # extensions. Otherwise libc10.so cannot be found.
8
+ import torch
9
+
10
+ # Wrap this in a try-block with better error message and
11
+ # instructions for building the c++ operations.
12
+ try:
13
+ from megablocks._ops import ops # type: ignore
14
+ except ModuleNotFoundError as e:
15
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
16
+
17
+
18
+ # Autograd wrapper for replicate kernel.
19
+ class ReplicateOp(torch.autograd.Function):
20
+
21
+ @staticmethod
22
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
23
+ ctx.save_for_backward(bins)
24
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
25
+ ops.replicate_forward(x, bins, out)
26
+ return out
27
+
28
+ @staticmethod
29
+ def backward(ctx: Any, grad: torch.Tensor):
30
+ bins, = ctx.saved_tensors
31
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
32
+ ops.replicate_backward(grad, bins, out)
33
+ return out, None, None
34
+
35
+
36
+ replicate = ReplicateOp.apply
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+
6
+
7
+ def round_up(x: torch.Tensor, value: int):
8
+ assert isinstance(value, int)
9
+ assert x.dtype == torch.int32
10
+
11
+ # TODO(tgale): If this becomes an issue
12
+ # do this in a custom kernel. We only expect
13
+ # to use this on arrays of less than 1k elements.
14
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
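A small worked example of the formula above, which rounds each entry of an int32 tensor up to the next multiple of value:

    import torch

    x = torch.tensor([1, 128, 129, 300], dtype=torch.int32)
    value = 128

    rounded = torch.div(x + (value - 1), value, rounding_mode='trunc') * value
    assert rounded.tolist() == [128, 128, 256, 384]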
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any, Optional
5
+
6
+ import torch
7
+ from stk.backend.autocast import custom_bwd, custom_fwd
8
+
9
+ from megablocks.backend import kernels
10
+
11
+
12
+ # Autograd wrapper for scatter kernel.
13
+ class ScatterOp(torch.autograd.Function):
14
+
15
+ @staticmethod
16
+ @custom_fwd
17
+ def forward(
18
+ ctx: Any,
19
+ x: torch.Tensor,
20
+ indices: torch.Tensor,
21
+ bin_ids: torch.Tensor,
22
+ weights: torch.Tensor,
23
+ bins: torch.Tensor,
24
+ top_k: int,
25
+ ) -> torch.Tensor:
26
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
27
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
28
+ ctx.top_k = top_k
29
+ ctx.x_shape = x.shape
30
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
31
+
32
+ @staticmethod
33
+ @custom_bwd
34
+ def backward(ctx: Any, grad: torch.Tensor):
35
+ grad = grad.contiguous()
36
+ saved_tensors = ctx.saved_tensors
37
+
38
+ indices, bin_ids, weights, bins = saved_tensors[:4]
39
+ dgrad = None
40
+ if ctx.needs_input_grad[0]:
41
+ dgrad = kernels.gather(
42
+ grad,
43
+ indices,
44
+ bin_ids,
45
+ weights,
46
+ bins,
47
+ ctx.top_k,
48
+ )
49
+
50
+ wgrad = None
51
+ if ctx.needs_input_grad[3]: # need wgrad
52
+ x = saved_tensors[-1]
53
+ wgrad = kernels.scatter_wgrad(
54
+ x,
55
+ grad,
56
+ indices,
57
+ bin_ids,
58
+ bins,
59
+ ctx.top_k,
60
+ )
61
+ return dgrad, None, None, wgrad, None, None, None
62
+
63
+
64
+ def scatter(
65
+ x: torch.Tensor,
66
+ indices: torch.Tensor,
67
+ bin_ids: torch.Tensor,
68
+ weights: torch.Tensor,
69
+ bins: torch.Tensor,
70
+ top_k: int,
71
+ ) -> Optional[torch.Tensor]:
72
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any, Optional, Tuple
5
+
6
+ # NOTE: Torch needs to be imported before the custom
7
+ # extensions. Otherwise libc10.so cannot be found.
8
+ import torch
9
+
10
+ # Wrap this in a try-block with better error message and
11
+ # instructions for building the c++ operations.
12
+ try:
13
+ from megablocks._ops import ops # type: ignore
14
+ except ModuleNotFoundError as e:
15
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
16
+
17
+ _BITS_FOR_DTYPE = {
18
+ torch.int16: 16,
19
+ torch.int32: 32,
20
+ torch.int64: 64,
21
+ }
22
+
23
+
24
+ # Autograd wrapper for sort kernel.
25
+ # NOTE: Does not support gradients.
26
+ class SortOp(torch.autograd.Function):
27
+
28
+ @staticmethod
29
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
30
+ if end_bit is None:
31
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
32
+ x_out = torch.empty_like(x)
33
+ iota_out = torch.empty_like(x)
34
+ ops.sort(x, end_bit, x_out, iota_out)
35
+ return (x_out, iota_out)
36
+
37
+
38
+ sort = SortOp.apply
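On non-negative integer keys, the (x_out, iota_out) pair returned above corresponds to what torch.sort produces (sorted values plus the permutation that sorts them); the radix sort over end_bit bits is there for speed, and tie ordering may differ between the two. A CPU reference sketch:

    import torch

    x = torch.tensor([3, 0, 2, 2, 1], dtype=torch.int32)

    sorted_vals, sorted_idx = torch.sort(x)
    # sorted_vals -> tensor([0, 1, 2, 2, 3], dtype=torch.int32)
    # sorted_idx  -> permutation that sorts x, analogous to iota_out.
    assert torch.equal(sorted_vals, x[sorted_idx])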
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import unittest
5
+
6
+ import numpy as np
7
+ import torch
8
+ from absl.testing import parameterized
9
+
10
+ from megablocks import ops
11
+
12
+ _SORT_TESTS = (
13
+ (16384, torch.int32, None),
14
+ (16384, torch.int32, 2),
15
+ (16384, torch.int32, 128),
16
+ )
17
+
18
+ _BASELINE_SORT_TESTS = ((16384,),)
19
+
20
+
21
+ def numpy_dtype(dtype):
22
+ types = {
23
+ torch.int16: np.int16,
24
+ torch.int32: np.int32,
25
+ torch.int64: np.int64,
26
+ }
27
+ return types[dtype]
28
+
29
+
30
+ def benchmark_function(fn, iterations=10):
31
+ # Run once to get rid of startup overhead.
32
+ fn()
33
+ times = []
34
+ for _ in range(iterations):
35
+ start = torch.cuda.Event(enable_timing=True)
36
+ end = torch.cuda.Event(enable_timing=True)
37
+ start.record()
38
+ fn()
39
+ end.record()
40
+ torch.cuda.synchronize()
41
+ times.append(start.elapsed_time(end))
42
+ times = np.array(times)
43
+ return times.mean(), times.std(), times.max(), times.min()
44
+
45
+
46
+ def log_benchmark(arguments, mean_t, std_t):
47
+ print('=' * 60)
48
+ print('Benchmark Parameters:')
49
+ for (key, value) in arguments.items():
50
+ print(f'{key} = {value}')
51
+ print('Results:')
52
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
53
+ print('=' * 60)
54
+
55
+
56
+ class SortBenchmark(parameterized.TestCase):
57
+
58
+ @parameterized.parameters(*_SORT_TESTS)
59
+ def testSort(self, n, dtype, max_val):
60
+ if max_val is None:
61
+ max_val = np.iinfo(numpy_dtype(dtype)).max
62
+ end_bit = int(np.ceil(np.log2(max_val)))
63
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
64
+
65
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
66
+ arguments = {
67
+ 'n': n,
68
+ 'dtype': dtype,
69
+ 'max_val': max_val,
70
+ }
71
+ log_benchmark(arguments, mean_t, std_t)
72
+
73
+ @parameterized.parameters(*_BASELINE_SORT_TESTS)
74
+ def testTorchSort(self, n):
75
+ x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
76
+
77
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
78
+ arguments = {
79
+ 'n': n,
80
+ }
81
+ log_benchmark(arguments, mean_t, std_t)
82
+
83
+
84
+ if __name__ == '__main__':
85
+ unittest.main()
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ import torch
4
+
5
+
6
+ def sum(x: torch.Tensor, dim: int = 0):
7
+ if x.shape[dim] == 1:
8
+ return x.squeeze(dim=dim)
9
+ return x.sum(dim=dim)
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ # NOTE: Torch needs to be imported before the custom
7
+ # extensions. Otherwise libc10.so cannot be found.
8
+ import torch
9
+
10
+ # Wrap this in a try-block with better error message and
11
+ # instructions for building the c++ operations.
12
+ try:
13
+ from megablocks._ops import ops # type: ignore
14
+ except ModuleNotFoundError as e:
15
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
16
+
17
+
18
+ # Autograd wrapper for topology kernel.
19
+ # NOTE: Does not support gradients.
20
+ class TopologyOp(torch.autograd.Function):
21
+
22
+ @staticmethod
23
+ def forward(
24
+ ctx: Any,
25
+ padded_bins: torch.Tensor,
26
+ block_size: int,
27
+ output_block_rows: int,
28
+ output_block_columns: int,
29
+ ):
30
+ out = torch.empty(
31
+ output_block_rows * output_block_columns,
32
+ dtype=torch.int16,
33
+ device=padded_bins.device,
34
+ )
35
+ ops.indices(
36
+ padded_bins,
37
+ block_size,
38
+ output_block_rows,
39
+ output_block_columns,
40
+ out,
41
+ )
42
+ return out
43
+
44
+
45
+ topology = TopologyOp.apply
build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py ADDED
@@ -0,0 +1,195 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+
6
+ from ._ops import ops
7
+
8
+ from megablocks.layers.arguments import Arguments
9
+ from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE
10
+ from megablocks.layers.glu import SparseGLU
11
+ from megablocks.layers.mlp import MLP, SparseMLP
12
+ from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss
13
+
14
+ # This section contains the direct kernel exports (not included in the original code)
15
+ def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
16
+ """
17
+ Compute exclusive cumulative sum along the specified dimension.
18
+
19
+ Args:
20
+ x: Input tensor
21
+ dim: Dimension along which to compute cumsum
22
+ out: Output tensor (modified in-place)
23
+
24
+ Returns:
25
+ The output tensor
26
+ """
27
+ result = ops.exclusive_cumsum(x, dim)
28
+ out.copy_(result)
29
+ return out
30
+
31
+
32
+ def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
33
+ """
34
+ Compute inclusive cumulative sum along the specified dimension.
35
+
36
+ Args:
37
+ x: Input tensor
38
+ dim: Dimension along which to compute cumsum
39
+ out: Output tensor (modified in-place)
40
+
41
+ Returns:
42
+ The output tensor
43
+ """
44
+ result = ops.inclusive_cumsum(x, dim)
45
+ out.copy_(result)
46
+ return out
47
+
48
+
49
+ def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
50
+ """
51
+ Compute histogram of input tensor values.
52
+
53
+ Args:
54
+ x: Input tensor
55
+ num_bins: Number of histogram bins
56
+
57
+ Returns:
58
+ Histogram tensor with counts for each bin
59
+ """
60
+ return ops.histogram(x, num_bins)
61
+
62
+
63
+ def indices(
64
+ padded_bins: torch.Tensor,
65
+ block_size: int,
66
+ output_block_rows: int,
67
+ output_block_columns: int,
68
+ ) -> torch.Tensor:
69
+ """
70
+ Construct indices from padded bins for sparse operations.
71
+
72
+ Args:
73
+ padded_bins: Tensor containing bin boundaries
74
+ block_size: Size of each block
75
+ output_block_rows: Number of rows in output blocks
76
+ output_block_columns: Number of columns in output blocks
77
+
78
+ Returns:
79
+ Tensor containing constructed indices
80
+ """
81
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
82
+
83
+
84
+ def replicate_forward(
85
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
86
+ ) -> torch.Tensor:
87
+ """
88
+ Forward pass of replicate operation - replicate values according to bin sizes.
89
+
90
+ Args:
91
+ x: Input tensor with values to replicate
92
+ bins: Tensor containing bin sizes
93
+ out: Output tensor (modified in-place)
94
+
95
+ Returns:
96
+ The output tensor
97
+ """
98
+ return ops.replicate_forward(x, bins, out)
99
+
100
+
101
+ def replicate_backward(
102
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
103
+ ) -> torch.Tensor:
104
+ """
105
+ Backward pass of replicate operation - reduce gradients back to bins.
106
+
107
+ Args:
108
+ grad: Gradient tensor to reduce
109
+ bins: Tensor containing bin sizes
110
+ out: Output tensor (modified in-place)
111
+
112
+ Returns:
113
+ The output tensor
114
+ """
115
+ return ops.replicate_backward(grad, bins, out)
116
+
117
+
118
+ def sort(
119
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
120
+ ) -> torch.Tensor:
121
+ """
122
+ Radix sort with index tracking.
123
+
124
+ Args:
125
+ x: Input tensor to sort
126
+ end_bit: Number of bits to consider in sorting
127
+ x_out: Output tensor for sorted values
128
+ iota_out: Output tensor for sorted indices
129
+
130
+ Returns:
131
+ The sorted values tensor
132
+ """
133
+ return ops.sort(x, end_bit, x_out, iota_out)
134
+
135
+
136
+ # Convenience functions for common use cases
137
+ def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
138
+ """
139
+ Compute cumulative sum with automatic output allocation.
140
+
141
+ Args:
142
+ x: Input tensor
143
+ dim: Dimension along which to compute cumsum (default: last dimension)
144
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
145
+
146
+ Returns:
147
+ New tensor containing the cumulative sum
148
+ """
149
+ out = torch.empty_like(x)
150
+ if exclusive:
151
+ return exclusive_cumsum(x, dim, out)
152
+ else:
153
+ return inclusive_cumsum(x, dim, out)
154
+
155
+
156
+ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
157
+ """
158
+ Sort tensor and return both sorted values and indices.
159
+
160
+ Args:
161
+ x: Input tensor to sort
162
+ end_bit: Number of bits to consider in sorting
163
+
164
+ Returns:
165
+ Tuple of (sorted_values, sorted_indices)
166
+ """
167
+ x_out = torch.empty_like(x)
168
+ iota_out = torch.empty_like(x)
169
+ sort(x, end_bit, x_out, iota_out)
170
+ return x_out, iota_out
171
+
172
+
173
+ # Export public API
174
+ __all__ = [
175
+ # Direct kernel exports
176
+ "exclusive_cumsum",
177
+ "inclusive_cumsum",
178
+ "histogram",
179
+ "indices",
180
+ "replicate_forward",
181
+ "replicate_backward",
182
+ "sort",
183
+ "cumsum",
184
+ "argsort",
185
+ # Original exports
186
+ "Arguments",
187
+ "ParallelDroplessMLP",
188
+ "dMoE",
189
+ "SparseGLU",
190
+ "MLP",
191
+ "SparseMLP",
192
+ "MoE",
193
+ "ParallelMLP",
194
+ "get_load_balancing_loss",
195
+ ]
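A short usage sketch of the convenience wrappers defined above (cumsum and argsort allocate their outputs and then call the out-parameter kernels). This is a hypothetical illustration only: it assumes the compiled _megablocks extension is importable, a CUDA device is present, and that the underlying ops accept 1-D int32 CUDA tensors with these arguments.

    import torch
    import megablocks

    top_experts = torch.randint(0, 8, (1024,), dtype=torch.int32, device='cuda')

    # Tokens routed to each of the 8 experts (assumed shape/dtype support).
    counts = megablocks.histogram(top_experts, 8)

    # Starting offset of each expert's bin via an exclusive prefix sum.
    offsets = megablocks.cumsum(counts, dim=0, exclusive=True)

    # Sorted expert ids plus the permutation that groups tokens by expert.
    sorted_ids, permutation = megablocks.argsort(top_experts)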
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12cf047c4bcb5f368f490ada3dafcecd1242a443ff5ded94c09d62548922098c
3
+ size 11795992
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _megablocks_359242d
3
+ ops = torch.ops._megablocks_359242d
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_megablocks_359242d::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_version.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """The MegaBlocks Version."""
5
+
6
+ __version__ = '0.11.0.dev0'