drbh
committed on
Commit · a585153
1 Parent(s): 359242d
feat: add build output
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +2 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py +195 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so +3 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +9 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/_version.py +6 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py +2 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py +543 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py +23 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py +35 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py +26 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py +10 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/activation_fn.py +33 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/all_to_all.py +54 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py +100 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py +26 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py +42 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py +327 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/gelu.py +43 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py +223 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py +102 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py +574 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py +475 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py +93 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py +114 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py +30 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py +35 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +60 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py +37 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +59 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py +52 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py +38 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py +27 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py +78 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py +403 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py +55 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +98 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +66 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py +149 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py +10 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py +36 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py +14 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py +72 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py +38 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py +85 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py +9 -0
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py +45 -0
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py +195 -0
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so +3 -0
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +9 -0
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/_version.py +6 -0
.gitattributes
CHANGED
@@ -32,4 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py
ADDED
@@ -0,0 +1,195 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from megablocks.layers.arguments import Arguments
+from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE
+from megablocks.layers.glu import SparseGLU
+from megablocks.layers.mlp import MLP, SparseMLP
+from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+    """
+    Compute exclusive cumulative sum along the specified dimension.
+
+    Args:
+        x: Input tensor
+        dim: Dimension along which to compute cumsum
+        out: Output tensor (modified in-place)
+
+    Returns:
+        The output tensor
+    """
+    result = ops.exclusive_cumsum(x, dim)
+    out.copy_(result)
+    return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+    """
+    Compute inclusive cumulative sum along the specified dimension.
+
+    Args:
+        x: Input tensor
+        dim: Dimension along which to compute cumsum
+        out: Output tensor (modified in-place)
+
+    Returns:
+        The output tensor
+    """
+    result = ops.inclusive_cumsum(x, dim)
+    out.copy_(result)
+    return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+    """
+    Compute histogram of input tensor values.
+
+    Args:
+        x: Input tensor
+        num_bins: Number of histogram bins
+
+    Returns:
+        Histogram tensor with counts for each bin
+    """
+    return ops.histogram(x, num_bins)
+
+
+def indices(
+    padded_bins: torch.Tensor,
+    block_size: int,
+    output_block_rows: int,
+    output_block_columns: int,
+) -> torch.Tensor:
+    """
+    Construct indices from padded bins for sparse operations.
+
+    Args:
+        padded_bins: Tensor containing bin boundaries
+        block_size: Size of each block
+        output_block_rows: Number of rows in output blocks
+        output_block_columns: Number of columns in output blocks
+
+    Returns:
+        Tensor containing constructed indices
+    """
+    return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+    x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+    """
+    Forward pass of replicate operation - replicate values according to bin sizes.
+
+    Args:
+        x: Input tensor with values to replicate
+        bins: Tensor containing bin sizes
+        out: Output tensor (modified in-place)
+
+    Returns:
+        The output tensor
+    """
+    return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+    grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+    """
+    Backward pass of replicate operation - reduce gradients back to bins.
+
+    Args:
+        grad: Gradient tensor to reduce
+        bins: Tensor containing bin sizes
+        out: Output tensor (modified in-place)
+
+    Returns:
+        The output tensor
+    """
+    return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+    x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+    """
+    Radix sort with index tracking.
+
+    Args:
+        x: Input tensor to sort
+        end_bit: Number of bits to consider in sorting
+        x_out: Output tensor for sorted values
+        iota_out: Output tensor for sorted indices
+
+    Returns:
+        The sorted values tensor
+    """
+    return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+    """
+    Compute cumulative sum with automatic output allocation.
+
+    Args:
+        x: Input tensor
+        dim: Dimension along which to compute cumsum (default: last dimension)
+        exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+    Returns:
+        New tensor containing the cumulative sum
+    """
+    out = torch.empty_like(x)
+    if exclusive:
+        return exclusive_cumsum(x, dim, out)
+    else:
+        return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Sort tensor and return both sorted values and indices.
+
+    Args:
+        x: Input tensor to sort
+        end_bit: Number of bits to consider in sorting
+
+    Returns:
+        Tuple of (sorted_values, sorted_indices)
+    """
+    x_out = torch.empty_like(x)
+    iota_out = torch.empty_like(x)
+    sort(x, end_bit, x_out, iota_out)
+    return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+    # Direct kernel exports
+    "exclusive_cumsum",
+    "inclusive_cumsum",
+    "histogram",
+    "indices",
+    "replicate_forward",
+    "replicate_backward",
+    "sort",
+    "cumsum",
+    "argsort",
+    # Original exports
+    "Arguments",
+    "ParallelDroplessMLP",
+    "dMoE",
+    "SparseGLU",
+    "MLP",
+    "SparseMLP",
+    "MoE",
+    "ParallelMLP",
+    "get_load_balancing_loss",
+]
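Editorial note (not part of the diff): a minimal usage sketch of the convenience wrappers exported above. It assumes the built wheel is importable as `megablocks`, a CUDA device is available for the compiled kernels, and that the underlying ops accept the integer dtypes shown; dtype requirements may differ in practice.

    import torch
    import megablocks

    # Expert ids for 8 tokens routed across 4 experts (illustrative values).
    expert_ids = torch.randint(0, 4, (8,), dtype=torch.int32, device='cuda')

    # Count how many tokens landed in each expert bin.
    counts = megablocks.histogram(expert_ids, num_bins=4)

    # Bin boundaries via the inclusive cumulative-sum wrapper.
    bins = megablocks.cumsum(counts, dim=0)

    # Sort the expert ids and keep the permutation that produced the order.
    sorted_ids, permutation = megablocks.argsort(expert_ids)
    print(counts, bins, sorted_ids, permutation)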
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:36a918d61308a0acbc880516a42fcc2c4dc393f25655326e52128465c9501709
+size 10456376
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py
ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_359242d
+ops = torch.ops._megablocks_359242d
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_megablocks_359242d::{op_name}"
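For illustration only (not part of the commit): `add_op_namespace_prefix` just builds the fully qualified name of an op registered by the bundled extension. A trivial sketch:

    from megablocks._ops import add_op_namespace_prefix

    # Returns "_megablocks_359242d::exclusive_cumsum" for this build of the extension.
    qualified = add_op_namespace_prefix("exclusive_cumsum")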
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_version.py
ADDED
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'
build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py
ADDED
@@ -0,0 +1,2 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py
ADDED
@@ -0,0 +1,543 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+
+
+def assert_is_tensor(x, ndim):
+    if x.ndim != ndim:
+        raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
+
+
+def assert_is_matrix(x):
+    assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+    if x.ndim != 1:
+        raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
+
+
+def assert_equal(a, b):
+    if a != b:
+        raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_X': 64}, num_warps=2),
+        triton.Config({'BLOCK_X': 128}, num_warps=2),
+        triton.Config({'BLOCK_X': 256}, num_warps=2),
+        triton.Config({'BLOCK_X': 128}, num_warps=4),
+        triton.Config({'BLOCK_X': 256}, num_warps=4),
+    ],
+    key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy(
+    a,
+    b,
+    indices,
+    bin_ids,
+    weights,
+    bins,
+    padded_bins,
+    NUM_COLUMNS: tl.constexpr,
+    TOP_K: tl.constexpr,
+    BLOCK_X: tl.constexpr,
+    A_TO_B: tl.constexpr,
+    SCALE: tl.constexpr,
+):
+    # Our index into array 'a'.
+    index_a = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has greater or equal
+    # number of rows since they could be padded.
+    bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+    # Now we know what bin we're assigned to, but we need to know how
+    # many threadblocks were assigned to earlier bins so we can offset
+    # in our bin properly.
+    offset_in_bin = tl.program_id(0)
+    if bin_idx > 0:
+        offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+    # Load the starting index of our bin in array 'b'.
+    index_b = offset_in_bin
+    if bin_idx > 0:
+        index_b += tl.load(padded_bins + bin_idx - 1)
+
+    # Offset the input and output pointers.
+    #
+    # If we're going from A to B, divide the input index to copy
+    # the same input repeatedly. If we're going from B to A we
+    # need to reduce the result. Using atomics is slow, so we
+    # do the reduce step in a second kernel.
+    offset = index_a // TOP_K if A_TO_B else index_a
+    a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+    b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+    # Load the scale, if requested.
+    scale = tl.load(weights + index_a) if SCALE else 1
+
+    # Swap the pointers depending on the direction.
+    iptr = a if A_TO_B else b
+    optr = b if A_TO_B else a
+
+    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+    for _ in range(iterations):
+        mask = offsets < NUM_COLUMNS
+        x = tl.load(iptr + offsets, mask=mask)
+        x = x.to(tl.float32) * scale.to(tl.float32)
+
+        tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+        offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+    # Validate the input shapes.
+    assert_is_matrix(x)
+    assert_is_vector(indices)
+    assert_is_vector(bin_ids)
+    assert_is_vector(bins)
+    assert_is_vector(padded_bins)
+    assert_equal(indices.shape[0], x.shape[0] * top_k)
+    assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+    assert_equal(bins.size(), padded_bins.size())
+
+    if weights is not None:
+        assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+    # NOTE: Because of the padding, the output size is dynamic.
+    # We load the final padded bin bound to get the output rows.
+    output_rows = padded_bins[-1].cpu().item()
+    out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+    _padded_copy[(indices.shape[0],)](
+        x,
+        out,
+        indices,
+        bin_ids,
+        weights,
+        bins,
+        padded_bins,
+        NUM_COLUMNS=x.shape[1],
+        A_TO_B=True,
+        TOP_K=top_k,
+        SCALE=weights is not None,
+    )
+    return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+    # Validate the input shapes.
+    assert_is_matrix(x)
+    assert_is_vector(indices)
+    assert_is_vector(bin_ids)
+    assert_is_vector(bins)
+    assert_equal(indices.shape[0], x.shape[0] * top_k)
+    assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+    if weights is not None:
+        assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+    # NOTE: There is no padding so the output rows equals the
+    # input rows multiplied by top_k.
+    output_rows = x.shape[0] * top_k
+    out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+    _padded_copy[(indices.shape[0],)](
+        x,
+        out,
+        indices,
+        bin_ids,
+        weights,
+        bins,
+        bins,
+        NUM_COLUMNS=x.shape[1],
+        A_TO_B=True,
+        TOP_K=top_k,
+        SCALE=weights is not None,
+    )
+    return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+    # Validate the input shapes.
+    assert_is_matrix(x)
+    assert_is_vector(indices)
+    assert_is_vector(bin_ids)
+    assert_is_vector(bins)
+    assert_is_vector(padded_bins)
+    assert_equal(indices.shape[0], bin_ids.shape[0])
+    assert_equal(bins.size(), padded_bins.size())
+
+    if weights is not None:
+        assert_equal(indices.shape[0], weights.shape[0])
+
+    tokens = indices.shape[0] // top_k
+    out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+    _padded_copy[(indices.shape[0],)](
+        out,
+        x,
+        indices,
+        bin_ids,
+        weights,
+        bins,
+        padded_bins,
+        NUM_COLUMNS=x.shape[1],
+        A_TO_B=False,
+        TOP_K=top_k,
+        SCALE=weights is not None,
+    )
+
+    # Reduce along the top-k dimension, if needed.
+    return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+    return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_X': 64}, num_warps=2),
+        triton.Config({'BLOCK_X': 128}, num_warps=2),
+        triton.Config({'BLOCK_X': 256}, num_warps=2),
+        triton.Config({'BLOCK_X': 128}, num_warps=4),
+        triton.Config({'BLOCK_X': 256}, num_warps=4),
+    ],
+    key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _padded_copy_wgrad(
+    x,
+    grad,
+    wgrad,
+    indices,
+    bin_ids,
+    bins,
+    padded_bins,
+    NUM_COLUMNS: tl.constexpr,
+    TOP_K: tl.constexpr,
+    BLOCK_X: tl.constexpr,
+):
+    # Our index into 'tokens * top_k'.
+    index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per row in 'a'. Array 'b' has greater or equal
+    # number of rows since they could be padded.
+    bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+    # Now we know what bin we're assigned to, but we need to know how
+    # many threadblocks were assigned to earlier bins so we can offset
+    # in our bin properly.
+    offset_in_bin = tl.program_id(0)
+    if bin_idx > 0:
+        offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+    # Load the starting index of our bin in array 'x'.
+    index_x = offset_in_bin
+    if bin_idx > 0:
+        index_x += tl.load(padded_bins + bin_idx - 1)
+
+    # Offset the input and output pointers.
+    wgrad += index_out
+    grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+    x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+    acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+    for _ in range(iterations):
+        mask = offsets < NUM_COLUMNS
+        data = tl.load(x + offsets, mask=mask).to(tl.float32)
+        scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+        acc += data * scale
+        offsets += BLOCK_X
+
+    # Reduce to get the final result and store.
+    out = tl.sum(acc).to(wgrad.dtype.element_ty)
+    tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+    # Validate the input shapes.
+    assert_is_matrix(x)
+    assert_is_matrix(grad)
+    assert_is_vector(indices)
+    assert_is_vector(bin_ids)
+    assert_is_vector(bins)
+    assert_is_vector(padded_bins)
+    assert_equal(indices.shape[0], bin_ids.shape[0])
+    assert_equal(bins.size(), padded_bins.size())
+
+    tokens = indices.shape[0] // top_k
+    out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+    _padded_copy_wgrad[(indices.shape[0],)](
+        x,
+        grad,
+        out,
+        indices,
+        bin_ids,
+        bins,
+        padded_bins,
+        NUM_COLUMNS=x.shape[1],
+        TOP_K=top_k,
+    )
+    return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+    return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_X': 64}, num_warps=2),
+        triton.Config({'BLOCK_X': 128}, num_warps=2),
+        triton.Config({'BLOCK_X': 256}, num_warps=2),
+        triton.Config({'BLOCK_X': 128}, num_warps=4),
+        triton.Config({'BLOCK_X': 256}, num_warps=4),
+    ],
+    key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy(
+    a,
+    b,
+    num_experts,
+    expert_capacity,
+    indices,
+    weights,
+    bins,
+    NUM_COLUMNS: tl.constexpr,
+    TOP_K: tl.constexpr,
+    BLOCK_X: tl.constexpr,
+    A_TO_B: tl.constexpr,
+    SCALE: tl.constexpr,
+):
+    # Load our indices into the output.
+    expert_idx = tl.program_id(0)
+    entry_idx = tl.program_id(1)
+
+    # Calculate our offset into the output.
+    index_b = expert_idx * expert_capacity + entry_idx
+
+    # Load the index bounds for our bin and calculate
+    # the number of tokens assigned to our expert.
+    start = 0
+    if expert_idx > 0:
+        start = tl.load(bins + expert_idx - 1)
+    end = tl.load(bins + expert_idx)
+    num_tokens = end - start
+
+    # Calculate our offset into the input. If we don't
+    # have an input exit early.
+    if entry_idx >= num_tokens:
+        return
+    index_a = tl.load(indices + start + entry_idx)
+
+    # Offset the input and output pointers.
+    #
+    # If we're going from A to B, divide the input index to copy
+    # the same input repeatedly. If we're going from B to A we
+    # need to reduce the result. Using atomics is slow, so we
+    # do the reduce step in a second kernel.
+    offset = index_a // TOP_K if A_TO_B else index_a
+    a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+    b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+    # Load the scale, if requested.
+    scale = tl.load(weights + index_a) if SCALE else 1
+
+    # Swap the pointers depending on the direction.
+    #
+    # NOTE: We need to zero the output in both directions.
+    iptr = a if A_TO_B else b
+    optr = b if A_TO_B else a
+
+    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+    for _ in range(iterations):
+        mask = offsets < NUM_COLUMNS
+        x = tl.load(iptr + offsets, mask=mask)
+        x = x.to(tl.float32) * scale.to(tl.float32)
+
+        tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+        offsets += BLOCK_X
+
+
+def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
+    # Validate the input shapes.
+    assert_is_matrix(x)
+    assert_is_vector(indices)
+    assert_is_vector(bins)
+    assert_equal(indices.shape[0], x.shape[0] * top_k)
+
+    if weights is not None:
+        assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+    num_experts = bins.shape[0]
+    out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
+
+    _binned_copy[(num_experts, expert_capacity)](
+        x,
+        out,
+        num_experts,
+        expert_capacity,
+        indices,
+        weights,
+        bins,
+        NUM_COLUMNS=x.shape[1],
+        A_TO_B=True,
+        TOP_K=top_k,
+        SCALE=weights is not None,
+    )
+    return out
+
+
+def binned_scatter(x, indices, weights, bins, top_k):
+    # Validate the input shapes.
+    assert_is_tensor(x, 3)
+    assert_is_vector(indices)
+    assert_is_vector(bins)
+    assert_equal(bins.shape[0], x.shape[0])
+
+    if weights is not None:
+        assert_equal(indices.shape[0], weights.shape[0])
+
+    num_experts, expert_capacity, hidden_size = x.shape
+    tokens = indices.shape[0] // top_k
+    out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
+    _binned_copy[(num_experts, expert_capacity)](
+        out,
+        x,
+        num_experts,
+        expert_capacity,
+        indices,
+        weights,
+        bins,
+        NUM_COLUMNS=hidden_size,
+        A_TO_B=False,
+        TOP_K=top_k,
+        SCALE=weights is not None,
+    )
+
+    # Reduce along the top-k dimension, if needed.
+    return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
+
+
+# a: (tokens, hidden_size), real.
+# b: (num_experts, expert_capacity, num_columns), real.
+# indices: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_X': 64}, num_warps=2),
+        triton.Config({'BLOCK_X': 128}, num_warps=2),
+        triton.Config({'BLOCK_X': 256}, num_warps=2),
+        triton.Config({'BLOCK_X': 128}, num_warps=4),
+        triton.Config({'BLOCK_X': 256}, num_warps=4),
+    ],
+    key=['NUM_COLUMNS'],
+)
+@triton.jit
+def _binned_copy_wgrad(
+    x,
+    grad,
+    wgrad,
+    num_experts,
+    expert_capacity,
+    indices,
+    bins,
+    NUM_COLUMNS: tl.constexpr,
+    TOP_K: tl.constexpr,
+    BLOCK_X: tl.constexpr,
+):
+    # Load our indices into the output.
+    expert_idx = tl.program_id(0)
+    entry_idx = tl.program_id(1)
+
+    # Calculate our offset into the output.
+    index_x = expert_idx * expert_capacity + entry_idx
+
+    # Load the index bounds for our bin and calculate
+    # the number of tokens assigned to our expert.
+    start = 0
+    if expert_idx > 0:
+        start = tl.load(bins + expert_idx - 1)
+    end = tl.load(bins + expert_idx)
+    num_tokens = end - start
+
+    # Calculate our offset into the input. If we don't
+    # have an input exit early.
+    if entry_idx >= num_tokens:
+        return
+    index_out = tl.load(indices + start + entry_idx)
+
+    # Offset the input and output pointers.
+    wgrad += index_out
+    grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+    x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+    acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+    for _ in range(iterations):
+        mask = offsets < NUM_COLUMNS
+        data = tl.load(x + offsets, mask=mask).to(tl.float32)
+        scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+        acc += data * scale
+        offsets += BLOCK_X
+
+    # Reduce to get the final result and store.
+    out = tl.sum(acc).to(wgrad.dtype.element_ty)
+    tl.store(wgrad, out)
+
+
+def binned_scatter_wgrad(x, grad, indices, bins, top_k):
+    # Validate the input shapes.
+    assert_is_tensor(x, 3)
+    assert_is_matrix(grad)
+    assert_is_vector(indices)
+    assert_is_vector(bins)
+    assert_equal(bins.shape[0], x.shape[0])
+
+    num_experts, expert_capacity, hidden_size = x.shape
+    tokens = indices.shape[0] // top_k
+    out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
+    _binned_copy_wgrad[(num_experts, expert_capacity)](
+        x,
+        grad,
+        out,
+        num_experts,
+        expert_capacity,
+        indices,
+        bins,
+        NUM_COLUMNS=hidden_size,
+        TOP_K=top_k,
+    )
+    return out
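Editorial note (not part of the diff): a rough usage sketch of the padded gather/scatter helpers above. It assumes a CUDA device with Triton available, `top_k = 1`, and integer `bins`/`padded_bins` built as inclusive cumulative token counts per expert (the padded variant rounded up to an illustrative block size); shapes follow the comments in the kernel file, and the block size shown is hypothetical.

    import torch
    from megablocks.backend import kernels

    top_k, hidden = 1, 16
    x = torch.randn(8, hidden, device='cuda', dtype=torch.float16)
    expert_ids = torch.randint(0, 2, (8,), device='cuda', dtype=torch.int32)

    # Sort tokens by expert and build the (padded) bin boundaries.
    bin_ids, order = torch.sort(expert_ids)
    indices = order.int()
    tokens_per_expert = torch.bincount(expert_ids, minlength=2).int()
    bins = torch.cumsum(tokens_per_expert, 0).int()
    blocking = 4  # illustrative padding granularity
    padded = ((tokens_per_expert + blocking - 1) // blocking) * blocking
    padded_bins = torch.cumsum(padded, 0).int()

    # Permute tokens into padded, expert-contiguous rows, then route them back.
    routed = kernels.padded_gather(x, indices, bin_ids, None, bins, padded_bins, top_k)
    restored = kernels.padded_scatter(routed, indices, bin_ids, None, bins, padded_bins, top_k)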
build/torch26-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py
ADDED
@@ -0,0 +1,23 @@
+from megablocks_moe.megablocks import (
+    MoE,
+    dMoE,
+    get_load_balancing_loss,
+    ParallelMLP,
+    ParallelDroplessMLP,
+    SparseMLP,
+    MLP,
+    SparseGLU,
+    Arguments,
+)
+
+__all__ = [
+    "MoE",
+    "dMoE",
+    "get_load_balancing_loss",
+    "ParallelMLP",
+    "ParallelDroplessMLP",
+    "SparseMLP",
+    "MLP",
+    "SparseGLU",
+    "Arguments",
+]
build/torch26-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py
ADDED
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+
+def log_benchmark(name, arguments, time, std):
+    print('=' * 60)
+    print(f'{name} Benchmark')
+    print('Benchmark Parameters:')
+    for (key, value) in arguments.items():
+        print(f'{key} = {value}')
+    print('Results:')
+    print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
+    print('=' * 60)
+
+
+def benchmark_function(fn, iterations=100, warmup=10):
+    # Warmup iterations.
+    for _ in range(warmup):
+        fn()
+
+    times = []
+    for i in range(iterations):
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+
+        start.record()
+        fn()
+        end.record()
+
+        torch.cuda.synchronize()
+        times.append(start.elapsed_time(end))
+    return np.mean(times), np.std(times)
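Illustrative only (not part of the commit): the helpers above can time any CUDA callable, roughly like this. It assumes a CUDA device; the matmul is just a stand-in workload.

    import torch
    from megablocks.benchmark_util import benchmark_function, log_benchmark

    a = torch.randn(1024, 1024, device='cuda')
    b = torch.randn(1024, 1024, device='cuda')

    mean_ms, std_ms = benchmark_function(lambda: a @ b, iterations=50, warmup=5)
    log_benchmark('matmul', {'shape': '1024x1024'}, mean_ms, std_ms)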
build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py
ADDED
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import warnings
+
+_grouped_gemm_is_available: bool = False
+try:
+    import grouped_gemm
+    _grouped_gemm_is_available = True
+except ImportError as error:
+    warnings.warn('Grouped GEMM not available.')
+
+
+def grouped_gemm_is_available():
+    return _grouped_gemm_is_available
+
+
+def assert_grouped_gemm_is_available():
+    msg = (
+        'Grouped GEMM not available. Please run '
+        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.',
+    )
+    assert _grouped_gemm_is_available, msg
+
+
+backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+ops = grouped_gemm.ops if grouped_gemm_is_available() else None
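As a usage note (not part of the diff), callers are expected to check availability before touching the grouped GEMM backend; a minimal sketch:

    from megablocks import grouped_gemm_util as gg

    if gg.grouped_gemm_is_available():
        grouped_ops = gg.ops  # thin re-export of grouped_gemm.ops
    else:
        # Raises an AssertionError carrying the install instructions above.
        gg.assert_grouped_gemm_is_available()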
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py
ADDED
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from megablocks.layers.moe import MoE
+
+__all__ = [
+    'MoE',
+    # 'dMoE',
+]
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/activation_fn.py
ADDED
@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from stk import Matrix
+
+
+def act_fn(
+    x: Matrix,
+    function: Callable,
+    return_grad_fn: bool = False,
+    **kwargs,
+) -> Union[tuple[Matrix, Any] | Matrix]:
+    assert isinstance(x, Matrix)
+    with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+        if return_grad_fn:
+            x.data.requires_grad = True
+        out = function(x.data, **kwargs)
+        y = Matrix(
+            x.size(),
+            out,
+            x.row_indices,
+            x.column_indices,
+            x.offsets,
+            x.column_indices_t,
+            x.offsets_t,
+            x.block_offsets_t,
+        )
+        if return_grad_fn:
+            return y, out.backward
+        return y
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/all_to_all.py
ADDED
@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+        out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+        ctx.input_shape = x.shape
+        ctx.output_split_sizes = output_split_sizes
+        ctx.input_split_sizes = input_split_sizes
+        ctx.group = group
+        handle = dist.all_to_all_single(
+            out,
+            x,
+            output_split_sizes=output_split_sizes,
+            input_split_sizes=input_split_sizes,
+            group=group,
+            async_op=async_op,
+        )
+        return out, handle
+
+    @staticmethod
+    def backward(ctx, grad, _):
+        if ctx.needs_input_grad[0]:
+            out = torch.empty(
+                ctx.input_shape,
+                device=grad.device,
+                dtype=grad.dtype,
+            )
+            dist.all_to_all_single(
+                out,
+                grad,
+                output_split_sizes=ctx.input_split_sizes,
+                input_split_sizes=ctx.output_split_sizes,
+                group=ctx.group,
+            )
+            return out, None, None, None, None
+        return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+    return AllToAllOp.apply(
+        x,
+        output_split_sizes,
+        input_split_sizes,
+        group,
+        async_op,
+    )
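A small sketch of how the wrapper above is typically called (not part of the diff). It assumes `torch.distributed` has already been initialized with a NCCL backend, every rank contributes the same split sizes, and a CUDA device is bound per rank; with `async_op=False` the returned handle may be None.

    import torch
    import torch.distributed as dist
    from megablocks.layers.all_to_all import all_to_all

    world_size = dist.get_world_size()
    x = torch.randn(4 * world_size, 8, device='cuda')

    # Send 4 rows to every rank and expect 4 rows back from each.
    splits = [4] * world_size
    out, handle = all_to_all(x, splits, splits, dist.group.WORLD, async_op=True)
    handle.wait()  # wait for the communication to finish before using `out`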
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py
ADDED
@@ -0,0 +1,100 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+import megablocks.grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+    # Model arguments.
+    hidden_size: int = 1024
+    ffn_hidden_size: int = 4096
+    num_layers: int = 1
+    bias: bool = True
+    return_bias: bool = True
+    activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+    # MoE arguments.
+    moe_num_experts: int = 1
+    moe_top_k: int = 1
+    moe_capacity_factor: int = 1
+    moe_normalize_expert_weights: Optional[Union[int, float]] = None
+    moe_loss_weight: float = 0.1
+    moe_jitter_eps: Optional[float] = None
+    moe_lbl_in_fp32: bool = False
+
+    # Parallelism arguments.
+    moe_expert_model_parallelism: bool = False
+    expert_parallel_group: Optional[dist.ProcessGroup] = None
+    pipeline_model_parallel_size: int = 1
+    num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+    # Compute arguments.
+    memory_optimized_mlp: bool = False
+    mlp_type: str = 'mlp'
+    mlp_impl: str = 'sparse'
+
+    # Initialization arguments.
+    fp16: bool = True
+    bf16: bool = False
+    device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+    init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+    output_layer_init_method: InitFn = init_method
+
+    # Benchmarking arguments.
+    uniform_expert_assignment: bool = False
+
+    # shared expert arguments
+    shared_expert: bool = False  # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
+    fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,)  # kwargs for custom fc layers
+    remat_act_fn: bool = True  # enable act fn to be rematerialized instead of stored
+    shared_expert_hidden_size: Optional[
+        int] = None  # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+    # Router Z-loss arguments
+    moe_zloss_weight: float = 0  # 1e-3 is a reasonable value
+    moe_zloss_in_fp32: bool = False
+
+    def __post_init__(self):
+        # Sparse MLP is not supported with triton >=3.2.0
+        # TODO: Remove this once sparse is supported with triton >=3.2.0
+        if self.__getattribute__('mlp_impl') == 'sparse':
+            try:
+                import triton
+                if triton.__version__ >= '3.2.0':
+                    raise ValueError(
+                        'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+                    )
+            except ImportError:
+                raise ImportError('Triton is required for sparse MLP implementation')
+
+        if self.__getattribute__('mlp_impl') == 'grouped':
+            grouped_gemm.assert_grouped_gemm_is_available()
+
+        if self.shared_expert_hidden_size is None:
+            self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+    args = Arguments()
+    for field in dataclasses.fields(args):
+        if hasattr(megatron_args, field.name):
+            setattr(args, field.name, getattr(megatron_args, field.name))
+    return args
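For orientation (not part of the commit), a typical dMoE configuration built from this dataclass might look as follows. Field names come from the definition above; the values are illustrative, and `__post_init__` validates the chosen backend, so 'grouped' requires the grouped_gemm package while the default 'sparse' requires triton < 3.2.0.

    import torch
    from megablocks.layers.arguments import Arguments

    args = Arguments(
        hidden_size=1024,
        ffn_hidden_size=4096,
        moe_num_experts=8,
        moe_top_k=2,
        mlp_impl='grouped',   # 'sparse' is rejected with triton >= 3.2.0
        fp16=False,
        bf16=True,
        device=torch.device('cuda'),
    )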
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py
ADDED
@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from megablocks.layers.arguments import Arguments
+
+
+def dtype(args: Arguments):
+    if args.fp16:
+        return torch.float16
+    elif args.bf16:
+        return torch.bfloat16
+    return None
+
+
+def cast_if_autocast_enabled(tensor):
+    if torch.is_autocast_enabled():
+        if tensor.device.type == 'cuda':
+            dtype = torch.get_autocast_gpu_dtype()
+        elif tensor.device.type == 'cpu':
+            dtype = torch.get_autocast_cpu_dtype()
+        else:
+            raise NotImplementedError()
+        return tensor.to(dtype=dtype)
+    return tensor
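Illustrative usage of the autocast helper above (not part of the diff); it assumes a CUDA device.

    import torch
    from megablocks.layers import common

    x = torch.randn(4, 4, device='cuda')
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        # Inside autocast this returns a bfloat16 copy of `x`; outside, `x` is returned unchanged.
        y = common.cast_if_autocast_enabled(x)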
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py
ADDED
@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from megablocks.layers import glu, mlp
+from megablocks.layers.arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+    'mlp': {
+        'grouped': mlp.GroupedMLP,
+        'sparse': mlp.SparseMLP,
+    },
+    'glu': {
+        'grouped': glu.GroupedGLU,
+        'sparse': glu.SparseGLU,
+    },
+}
+
+
+def get(args: Arguments) -> MlpType:
+    """Returns an MLP for use in a dMoE instance.
+
+    Uses the provided arguments to instantiate the appropriate
+    MLP instance. This only contains MLPs for use in dMoEs
+    (ie. only for the dropless versions of MoEs).
+
+    Args:
+        args: propagated Arguments dataclass.
+
+    Returns:
+        An instantiated MLP constructed using the input args.
+    """
+    if args.mlp_type not in _REGISTRY:
+        raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+    if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+        raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+    return _REGISTRY[args.mlp_type][args.mlp_impl](args)
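A sketch of how the registry is used (not part of the diff). It assumes an Arguments instance whose backend validation in `__post_init__` has passed and a CUDA device for the expert weight allocation.

    from megablocks.layers import dmlp_registry
    from megablocks.layers.arguments import Arguments

    args = Arguments(mlp_type='glu', mlp_impl='grouped')  # selects glu.GroupedGLU
    expert_mlp = dmlp_registry.get(args)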
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import stk.ops
|
6 |
+
import torch
|
7 |
+
from stk import Matrix
|
8 |
+
|
9 |
+
import megablocks.ops as ops
|
10 |
+
# from megablocks.ops import ops
|
11 |
+
from megablocks.layers import common, dmlp_registry, moe, mpu
|
12 |
+
from megablocks.layers.arguments import Arguments
|
13 |
+
|
14 |
+
|
15 |
+
def promote_scalar(x):
|
16 |
+
return x.view(1) if not len(x.size()) else x
|
17 |
+
|
18 |
+
|
19 |
+
class ParallelDroplessMLP(moe.ParallelMLP):
|
20 |
+
|
21 |
+
def __init__(self, args: Arguments):
|
22 |
+
super(ParallelDroplessMLP, self).__init__(args)
|
23 |
+
self.hidden_size = args.hidden_size
|
24 |
+
self.ffn_hidden_size = mpu.features_per_rank(args)
|
25 |
+
self.blocking = 128
|
26 |
+
self.mlp = dmlp_registry.get(args)
|
27 |
+
|
28 |
+
# Calculate the number of bits needed to represent the column indices
|
29 |
+
# in the intermediate sparse matrix.
|
30 |
+
max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
|
31 |
+
self.transpose_sort_end_bit = max(
|
32 |
+
int(np.ceil(np.log2(max_column_index))),
|
33 |
+
1,
|
34 |
+
)
|
35 |
+
|
36 |
+
def sparse_transpose(self, size, row_indices, column_indices, offsets):
|
37 |
+
block_columns = size[1] // self.blocking
|
38 |
+
|
39 |
+
# Sort row indices by column indices to get the transposed matrix's
|
40 |
+
# column indices.
|
41 |
+
#
|
42 |
+
# NOTE: Our sort operation uses the same width indices as the input values.
|
43 |
+
# To avoid overflow when we have large activation matrices we cast to
|
44 |
+
# 32-bit before sorting.
|
45 |
+
_, gather_indices = ops.sort(
|
46 |
+
column_indices.int(),
|
47 |
+
self.transpose_sort_end_bit,
|
48 |
+
)
|
49 |
+
|
50 |
+
# There are a constant number of blocks in every row of the sparse matrix.
|
51 |
+
# A blocks offset is:
|
52 |
+
#
|
53 |
+
# row_index * blocks_per_row + column_index % blocks_per_row
|
54 |
+
#
|
55 |
+
# Once we have the block offsets ordered for transposition we can divide
|
56 |
+
# by blocks_per_row to get the transposed column indices.
|
57 |
+
column_indices_t = row_indices.gather(0, gather_indices.long())
|
58 |
+
block_offsets_t = gather_indices.int()
|
59 |
+
|
60 |
+
zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
|
61 |
+
nnz_per_column = ops.histogram(column_indices, block_columns)
|
62 |
+
nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
|
63 |
+
if nnz_per_column.dim() == 0:
|
64 |
+
# This addresses an edge case when ffn_hidden_size is equal to self.blocking.
|
65 |
+
nnz_per_column = nnz_per_column.unsqueeze(0)
|
66 |
+
offsets_t = torch.cat([zero, nnz_per_column])
|
67 |
+
return column_indices_t, offsets_t, block_offsets_t
|
68 |
+
|
69 |
+
def topology(self, x, padded_bins):
|
70 |
+
padded_tokens, _ = x.size()
|
71 |
+
assert padded_tokens % self.blocking == 0
|
72 |
+
if self.ffn_hidden_size % self.blocking != 0:
|
73 |
+
raise ValueError(
|
74 |
+
f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
|
75 |
+
f'the block size {self.blocking}. Please update your configuration.',
|
76 |
+
)
|
77 |
+
|
78 |
+
# Offsets for the sparse matrix. All rows have the
|
79 |
+
# same number of nonzero blocks dictated by the
|
80 |
+
# dimensionality of a single expert.
|
81 |
+
block_rows = padded_tokens // self.blocking
|
82 |
+
blocks_per_row = self.ffn_hidden_size // self.blocking
|
83 |
+
offsets = torch.arange(
|
84 |
+
0,
|
85 |
+
block_rows * blocks_per_row + 1,
|
86 |
+
blocks_per_row,
|
87 |
+
dtype=torch.int32,
|
88 |
+
device=x.device,
|
89 |
+
)
|
90 |
+
|
91 |
+
# Indices for the sparse matrix. The indices for
|
92 |
+
# the intermediate matrix are dynamic depending
|
93 |
+
# on the mapping of tokens to experts.
|
94 |
+
column_indices = ops.topology(
|
95 |
+
padded_bins,
|
96 |
+
self.blocking,
|
97 |
+
block_rows,
|
98 |
+
blocks_per_row,
|
99 |
+
)
|
100 |
+
|
101 |
+
# TODO(tgale): This is unused. Remove the need for this in stk.
|
102 |
+
# For now, use meta init to save the device memory.
|
103 |
+
data = torch.empty(
|
104 |
+
column_indices.numel(),
|
105 |
+
self.blocking,
|
106 |
+
self.blocking,
|
107 |
+
dtype=common.dtype(self.args),
|
108 |
+
device='meta',
|
109 |
+
)
|
110 |
+
shape = (
|
111 |
+
padded_tokens,
|
112 |
+
self.ffn_hidden_size * mpu.experts_per_rank(self.args),
|
113 |
+
)
|
114 |
+
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
|
115 |
+
column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
|
116 |
+
shape,
|
117 |
+
row_indices,
|
118 |
+
column_indices,
|
119 |
+
offsets,
|
120 |
+
)
|
121 |
+
return stk.Matrix(
|
122 |
+
shape,
|
123 |
+
data,
|
124 |
+
row_indices,
|
125 |
+
column_indices,
|
126 |
+
offsets,
|
127 |
+
column_indices_t,
|
128 |
+
offsets_t,
|
129 |
+
block_offsets_t,
|
130 |
+
)
|
131 |
+
|
132 |
+
def indices_and_padded_bins(self, top_experts):
|
133 |
+
# Sort the expert ids to produce the scatter/gather
|
134 |
+
# indices for the permutation.
|
135 |
+
top_experts = top_experts.int()
|
136 |
+
bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
|
137 |
+
|
138 |
+
# Histogram the expert ids to identify the number of
|
139 |
+
# tokens routed to each expert.
|
140 |
+
tokens_per_expert = ops.histogram(top_experts, self.num_experts)
|
141 |
+
|
142 |
+
# Round the token counts up to the block size used in
|
143 |
+
# the matrix muliplications. Caculate the starting
|
144 |
+
# position of each bin.
|
145 |
+
padded_tokens_per_expert = ops.round_up(
|
146 |
+
tokens_per_expert,
|
147 |
+
self.blocking,
|
148 |
+
)
|
149 |
+
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
|
150 |
+
padded_bins = promote_scalar(padded_bins)
|
151 |
+
|
152 |
+
# Calculate the bin bounds for the sorted tokens.
|
153 |
+
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
|
154 |
+
bins = promote_scalar(bins)
|
155 |
+
return indices, bin_ids, bins, padded_bins, tokens_per_expert
|
156 |
+
|
    def sparse_forward_once(self, x, expert_weights, top_experts):
        # x: [sl, bs, hs]
        # expert_weights: [sl * bs, top-k]
        # top_experts: [sl * bs, top-k]
        expert_weights = expert_weights.flatten()
        top_experts = top_experts.flatten()
        with torch.no_grad():
            indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))

        # Route the tokens for MoE computation.
        x = x.view(-1, x.shape[-1])
        x = ops.padded_gather(
            x,
            indices,
            bin_ids,
            bins,
            padded_bins,
            self.top_k,
        )

        # Create the sparse matrix topology.
        with torch.no_grad():
            topo = self.topology(x, padded_bins)

        # Perform the expert computation.
        x = self.mlp(x, topo)

        # Un-route the data for the MoE output.
        x = ops.padded_scatter(
            x,
            indices,
            bin_ids,
            expert_weights,
            bins,
            padded_bins,
            self.top_k,
        )
        return x, tokens_per_expert

    # For use in the base-class parallel_forward_once.
    def sparse_permute_and_compute(
        self,
        x,
        tokens_per_expert,
        indices,
        bin_ids,
        expert_weights,
        bins,
        expert_capactiy,  # unused
        top_k,
    ):

        # Round the token counts up to the block size used in the matrix
        # multiplication. Calculate the starting position of each bin.
        padded_tokens_per_expert = ops.round_up(
            tokens_per_expert,
            self.blocking,
        )
        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
        padded_bins = promote_scalar(padded_bins)

        # Route the tokens for MoE computation.
        x = x.view(-1, x.shape[-1])
        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)

        # Create the sparse matrix topology.
        with torch.no_grad():
            topo = self.topology(x, padded_bins)

        # Perform the expert computation.
        x = self.mlp(x, topo)

        # Un-route the data for the MoE output.
        return ops.padded_scatter(
            x,
            indices,
            bin_ids,
            expert_weights,
            bins,
            padded_bins,
            top_k,
        )

    def grouped_forward_once(self, x, expert_weights, top_experts):
        # x: [sl, bs, hs]
        # expert_weights: [sl * bs, top-k]
        # top_experts: [sl * bs, top-k]
        expert_weights = expert_weights.flatten()
        top_experts = top_experts.flatten()
        with torch.no_grad():
            indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))

        out = self.grouped_permute_and_compute(
            x,
            tokens_per_expert,
            indices,
            bin_ids,
            expert_weights,
            bins,
            -1,  # unused
            self.args.moe_top_k,
        )
        return out, tokens_per_expert

    def grouped_permute_and_compute(
        self,
        x,
        tokens_per_expert,
        indices,
        bin_ids,
        expert_weights,
        bins,
        expert_capactiy,  # unused
        top_k,
    ):

        # Route the tokens for MoE computation.
        x = x.view(-1, x.shape[-1])
        x = ops.gather(x, indices, bin_ids, bins, top_k)

        # Perform the expert computation.
        x = self.mlp(x, tokens_per_expert)

        # Un-route the data for the MoE output.
        return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)

    def forward_once(self, x, expert_weights, top_experts):
        if self.args.mlp_impl == 'sparse':
            return self.sparse_forward_once(x, expert_weights, top_experts)
        else:
            return self.grouped_forward_once(x, expert_weights, top_experts)

    def permute_and_compute(
        self,
        x,
        tokens_per_expert,
        indices,
        bin_ids,
        expert_weights,
        bins,
        expert_capactiy,
        top_k,
    ):
        if self.args.mlp_impl == 'sparse':
            return self.sparse_permute_and_compute(
                x,
                tokens_per_expert,
                indices,
                bin_ids,
                expert_weights,
                bins,
                expert_capactiy,
                top_k,
            )
        else:
            return self.grouped_permute_and_compute(
                x,
                tokens_per_expert,
                indices,
                bin_ids,
                expert_weights,
                bins,
                expert_capactiy,
                top_k,
            )


class dMoE(moe.MoE):

    def _init_experts_mlp(self, args: Arguments):
        return ParallelDroplessMLP(args)
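As a usage sketch, mirroring the memory_test.py script later in this diff but with smaller, made-up sizes: construct Arguments, build a dMoE layer, and run a forward pass. This assumes a CUDA device with the kernels from this build available; it is illustrative, not an official example, and other Arguments fields keep their defaults.

import torch
from megablocks.layers import arguments, dmoe

args = arguments.Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
    fp16=False,
    bf16=True,
    device=torch.cuda.current_device(),
)
layer = dmoe.dMoE(args).cuda()

# [batch, sequence, hidden] input, as in memory_test.py.
x = torch.randn(4, 512, 1024, device='cuda', dtype=torch.bfloat16)
out, _ = layer(x)  # dropless MoE forward: no tokens are dropped
print(out.shape)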
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/gelu.py
ADDED
@@ -0,0 +1,43 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import stk
import torch
import torch.nn.functional as F


@torch.jit.script
def _gelu_backward_inplace(g, x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
    return g.mul_(ff)


def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
    # NOTE: The two sparse matrices must have the same topology.
    if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
        return stk.Matrix(
            x.size(),
            _gelu_backward_inplace(grad.data, x.data),
            x.row_indices,
            x.column_indices,
            x.offsets,
            x.column_indices_t,
            x.offsets_t,
            x.block_offsets_t,
        )
    return _gelu_backward_inplace(grad, x)


def gelu(x: stk.Matrix):
    assert isinstance(x, stk.Matrix)
    return stk.Matrix(
        x.size(),
        F.gelu(x.data, approximate='tanh'),
        x.row_indices,
        x.column_indices,
        x.offsets,
        x.column_indices_t,
        x.offsets_t,
        x.block_offsets_t,
    )
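A quick way to see that the hard-coded constants in _gelu_backward_inplace are the derivative of the tanh-approximated GeLU is to compare against autograd on a dense tensor. This is a standalone sanity-check sketch, independent of stk; it is not part of the library.

import torch
import torch.nn.functional as F

x = torch.randn(16, dtype=torch.float64, requires_grad=True)

# Autograd gradient of the tanh-approximated GeLU.
y = F.gelu(x, approximate='tanh')
y.backward(torch.ones_like(x))
autograd_grad = x.grad

# Closed-form derivative with the same constants as _gelu_backward_inplace.
with torch.no_grad():
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)

print(torch.allclose(autograd_grad, ff, atol=1e-6))  # True, up to the truncated constants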
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py
ADDED
@@ -0,0 +1,223 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import stk.ops
import torch

from megablocks import grouped_gemm_util as gg
from megablocks.layers import common, mpu
from megablocks.layers.activation_fn import act_fn
from megablocks.layers.arguments import Arguments
from megablocks.layers.mlp import (
    SharedMLP,
    SparseMLP,
    create_dmoe_expert_weights,
    resolve_dtensor,
)


class SparseGLU(SparseMLP):

    def __init__(self, args: Arguments):
        super().__init__(args)
        self.v1 = torch.nn.Parameter(
            torch.empty(
                self._num_rows_per_rank,
                args.hidden_size,
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
        with torch.no_grad():
            self.v1.copy_(
                create_dmoe_expert_weights(
                    args,
                    args.moe_num_experts,
                    args.ffn_hidden_size,
                    args.hidden_size,
                    args.init_method,
                ),
            )

        mpu.set_expert_model_parallel_attributes(
            self.v1,
            self._should_set_parallelism_attribute,
        )

    def forward(self, x, topo):
        if self.args.memory_optimized_mlp:
            raise NotImplementedError(
                'Memory optimized implementation not yet supported with GLU with sparse kernels.',
            )

        w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
        w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)

        # Compute the GLU.
        x1 = stk.ops.sdd(x, w1.t(), topo)
        x2 = stk.ops.sdd(x, v1.t(), topo)

        activation_fn_out = act_fn(x1, self.args.activation_fn)
        x1 = stk.ops.mul(activation_fn_out, x2)

        return stk.ops.dsd(x1, w2)


class MemoryOptimizedGroupedGLU(torch.autograd.Function):
    """GroupedMLP with manually scheduled memory reuse."""

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
        # Cast inputs using ctx dtype from AMP
        if ctx._fwd_used_autocast:
            x = x.to(ctx._dtype)
            w1 = w1.to(ctx._dtype)
            v1 = v1.to(ctx._dtype)
            w2 = w2.to(ctx._dtype)
        # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
        if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
            raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")

        # Layer 0: x @ w1.t().
        assert gg.backend is not None
        sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
        v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)

        # GeLU.
        activation_fn_out = activation_fn(sdd_out) * v1_out

        # Layer 1: x @ w2.
        dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)

        # NOTE: Save the input to the layer and the activation_fn input for
        # gradient computation. We'll re-compute the activation_fn forward
        # pass in the backward pass to avoid materializing another
        # intermediate.
        ctx.x_shape = x.shape
        ctx.sdd_out_shape = sdd_out.shape
        ctx.dtype = x.dtype
        ctx.activation_fn = activation_fn
        ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
        return dsd_out

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx, ddsd_out):
        if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
            raise ValueError('Expected all MLP inputs to need grad.')

        # Unpack saved tensors
        # dtype = ctx.dtype
        saved_tensors = ctx.saved_tensors
        w1, v1, w2 = saved_tensors[:3]
        batch_sizes = saved_tensors[3]
        x = saved_tensors[4]
        sdd_out, v1_out = saved_tensors[5:7]

        # Rematerialize activation_fn output.
        activation_fn = ctx.activation_fn
        with torch.set_grad_enabled(True):
            sdd_out.requires_grad = True
            v1_out.requires_grad = True
            activation_fn_out = activation_fn(sdd_out) * v1_out
            activation_grad_fn = activation_fn_out.backward

        # Compute dw2 with recomputed activation_fn output.
        assert gg.backend is not None
        dw2 = gg.backend.gmm(
            activation_fn_out,
            ddsd_out,
            batch_sizes,
            trans_a=True,
        )

        # Compute dactivation_fn_out.
        #
        # NOTE: We reuse the activation_fn_out allocation.
        dactivation_fn_out = activation_fn_out
        gg.backend.gmm(
            ddsd_out,
            w2,
            batch_sizes,
            trans_b=True,
            c=dactivation_fn_out,
        )

        # Compute dsdd_out.
        #
        # NOTE: This reuses the dactivation_fn_out allocation.
        assert activation_grad_fn is not None
        activation_grad_fn(dactivation_fn_out)
        dsdd_out = sdd_out.grad
        dv1_out = v1_out.grad

        # Compute dw1.
        dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)

        # Compute dv1.
        dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)

        # Compute dx.
        #
        # NOTE: This reuses the ddsd_out allocation.
        dx = ddsd_out
        gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
        dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
        return dx, dw1, dv1, dw2, None, None


memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply


class GroupedGLU(SparseGLU):

    def forward(self, x, tokens_per_expert):
        batch_sizes = tokens_per_expert.cpu().to(torch.long)
        w1, v1, w2 = (
            self.scale_grad(self.w1),
            self.scale_grad(self.v1),
            self.scale_grad(self.w2),
        )
        w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)

        # Re-shape the weights for the grouped GEMMs.
        ne = mpu.experts_per_rank(self.args)
        w1 = w1.view(ne, -1, self.args.hidden_size)
        v1 = v1.view(ne, -1, self.args.hidden_size)
        w2 = w2.view(ne, -1, self.args.hidden_size)

        if self.args.memory_optimized_mlp:
            return memory_optimized_grouped_glu(
                x,
                w1,
                v1,
                w2,
                batch_sizes,
                self.args.activation_fn,
            )

        # Compute the MLP.
        assert gg.ops is not None
        x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
        x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
        x1 = self.args.activation_fn(x1) * x2
        return gg.ops.gmm(x1, w2, batch_sizes)


class SharedGLU(SharedMLP):
    """GLU for shared expert.

    Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
    """

    def __init__(self, args: Arguments):
        super().__init__(args)
        self.gate_proj = args.fc_cls(
            args.hidden_size,
            self.args.shared_expert_hidden_size,
            **self.fc_kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
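For reference, the per-expert computation that SparseGLU and GroupedGLU implement with block-sparse or grouped GEMMs reduces to the dense form below. The sizes are made up, and F.gelu stands in for args.activation_fn; w2 is stored as [ffn_hidden_size, hidden_size], matching how the weights above are laid out, so the down-projection is a plain matmul without a transpose.

import torch
import torch.nn.functional as F

hidden_size, ffn_hidden_size, num_tokens = 8, 32, 5
x = torch.randn(num_tokens, hidden_size)
w1 = torch.randn(ffn_hidden_size, hidden_size)   # gate projection
v1 = torch.randn(ffn_hidden_size, hidden_size)   # up projection
w2 = torch.randn(ffn_hidden_size, hidden_size)   # down projection, row-major like w1

x1 = x @ w1.t()                           # corresponds to stk.ops.sdd(x, w1.t(), topo)
x2 = x @ v1.t()                           # corresponds to stk.ops.sdd(x, v1.t(), topo)
gated = F.gelu(x1, approximate='tanh') * x2
y = gated @ w2                            # corresponds to stk.ops.dsd(x1, w2)
print(y.shape)                            # torch.Size([5, 8])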
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py
ADDED
@@ -0,0 +1,102 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import gc

import torch
import torch.distributed as dist

from megablocks.layers import arguments, dmoe

_TESTS = ((8, 2048, 4096, 4096, 32, 4),)


def get_tensors():
    ptrs = set()
    out = []
    for obj in gc.get_objects():
        if torch.is_tensor(obj):
            if not obj.is_contiguous() or obj.data_ptr() in ptrs:
                continue
            out.append(obj)
            ptrs.add(obj.data_ptr())
    return out


def test_memory(
    group,
    batch_size,
    sequence_length,
    hidden_size,
    ffn_hidden_size,
    num_experts,
    top_k,
):
    args = arguments.Arguments(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
        moe_expert_model_parallelism=True,
        expert_parallel_group=group,
        fp16=False,
        bf16=True,
        device=torch.cuda.current_device(),
    )
    layer = dmoe.dMoE(args).cuda()

    x = torch.randn((batch_size, sequence_length, hidden_size),
                    device=torch.cuda.current_device(),
                    dtype=torch.bfloat16).requires_grad_(True)
    torch.cuda.empty_cache()

    # Run forward + backward.
    # with torch.autograd.detect_anomaly():
    out, _ = layer(x)
    out.mean().backward()

    # Report peak memory.
    mem = torch.cuda.max_memory_allocated()
    print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
    print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)

    # Calculate weight and gradient memory usage.
    weight_memory = 2 * (
        layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
    )

    def grad_numel(x):
        if x.grad is not None:
            return x.grad.numel()
        return 0

    grad_memory = 2 * (
        grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
    )
    weight_memory += grad_memory

    print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
    print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)

    # Manually calculate GPU memory usage from the garbage
    # collector.
    gc.collect()
    total = 0
    tensors = get_tensors()
    tensors = sorted(tensors, key=lambda x: -x.numel())
    for i, t in enumerate(tensors):
        total += t.numel()
        print(f'{i}: {t.shape}, {t.numel() * 2}')
    del tensors

    print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))


if __name__ == '__main__':
    assert dist.is_available()
    group = dist.init_process_group(backend='nccl')
    local_rank = dist.get_rank(group)
    torch.cuda.set_device(local_rank)

    for args in _TESTS:
        test_memory(group, *args)
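The byte accounting in this test multiplies element counts by two because it runs in bf16. The same tally can be written more generally with element_size(), as in this small standalone sketch (the tensor sizes are made up and independent of the test above):

import torch

# Bytes held by a set of bf16 tensors: numel * element_size (2 bytes per bf16 element).
tensors = [torch.empty(4096, 4096, dtype=torch.bfloat16) for _ in range(3)]
total_bytes = sum(t.numel() * t.element_size() for t in tensors)
print('Total Bytes Found = {:0.0f}MiB'.format(total_bytes / 1e6))  # roughly 100MiB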
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py
ADDED
@@ -0,0 +1,574 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Any

import stk
import stk.backend.triton_kernels
import stk.ops
import torch
from packaging import version

from megablocks import grouped_gemm_util as gg
from megablocks.layers import common, gelu, mpu
from megablocks.layers.activation_fn import act_fn
from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn


class ScaleGradient(torch.autograd.Function):

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx: Any, x: torch.Tensor, scale: float):
        ctx.scale = scale
        return x

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx: torch.Tensor, grad: torch.Tensor):
        return grad * ctx.scale, None


scale_gradient = ScaleGradient.apply


def resolve_dtensor(weight: torch.Tensor):
    if version.parse(torch.__version__) >= version.parse('2.0.0'):
        from torch.distributed._tensor import DTensor
        if isinstance(weight, DTensor):
            return weight.to_local()
    return weight


def create_moe_expert_weights(
    args: Arguments,
    num_experts: int,
    ffn_hidden_size: int,
    hidden_size: int,
    init_method: InitFn,
):
    # Create the entire weight matrix such that the sampled weights will
    # not vary between data parallelism and expert model parallelism for
    # the same random seed.
    master_weights = torch.empty(
        num_experts,
        ffn_hidden_size,
        hidden_size,
        device=args.device,
        dtype=common.dtype(args),
    )
    init_method(master_weights)

    if not args.moe_expert_model_parallelism:
        return master_weights

    # Calculate the amount of sharding in each dimension.
    expert_sharding_degree = mpu.expert_sharding_degree(args)
    hidden_sharding_degree = mpu.hidden_sharding_degree(args)

    # Calculate the experts per rank.
    #
    # NOTE: We assign ranks to be expert parallel before going
    # tensor parallel.
    rank = mpu.get_expert_parallel_rank(args)
    expert_rank = rank % expert_sharding_degree
    num_experts_per_rank = num_experts // expert_sharding_degree
    start_expert = expert_rank * num_experts_per_rank
    end_expert = (expert_rank + 1) * num_experts_per_rank

    # Calculate the rows per rank.
    row_rank = rank // expert_sharding_degree
    num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
    start_row = row_rank * num_rows_per_rank
    end_row = (row_rank + 1) * num_rows_per_rank

    # Slice the weight matrix to get the chunk for this rank.
    with torch.no_grad():
        weights = master_weights[start_expert:end_expert, start_row:end_row]
    return weights


class MLP(torch.nn.Module):

    def __init__(self, args: Arguments):
        super().__init__()
        self.args = args
        # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
        experts_per_rank = mpu.experts_per_rank(args)

        self.w1 = torch.nn.Parameter(
            torch.empty(
                experts_per_rank,
                args.hidden_size,
                mpu.features_per_rank(args),
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
        self.w2 = torch.nn.Parameter(
            torch.empty(
                experts_per_rank,
                mpu.features_per_rank(args),
                args.hidden_size,
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
        mpu.set_expert_model_parallel_attributes(
            self.w1,
            args.moe_expert_model_parallelism,
        )
        mpu.set_expert_model_parallel_attributes(
            self.w2,
            args.moe_expert_model_parallelism,
        )

        # Initialize the parameters for the MLP.
        #
        # NOTE: It is important that we create the weight tensors prior
        # to creating the master weights and slicing out the piece for
        # this rank. If the master weights are created first the PyTorch
        # caching allocator appears to use the same memory block for these
        # and the slice which causes large increases in our peak memory
        # usage.
        with torch.no_grad():
            w1 = create_moe_expert_weights(
                args,
                args.moe_num_experts,
                args.ffn_hidden_size,
                args.hidden_size,
                args.init_method,
            )
            self.w1.copy_(w1.transpose(1, 2).contiguous())
            self.w2.copy_(
                create_moe_expert_weights(
                    args,
                    args.moe_num_experts,
                    args.ffn_hidden_size,
                    args.hidden_size,
                    args.output_layer_init_method,
                ),
            )

        self.gradient_scale = None
        if self.args.moe_expert_model_parallelism:
            self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)

    def scale_grad(self, w):
        if self.gradient_scale is None:
            return w
        return scale_gradient(w, self.gradient_scale)

    def forward(self, x):
        w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
        w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
        x = torch.bmm(x, w1)
        x = self.args.activation_fn(x)
        return torch.bmm(x, w2)


def create_dmoe_expert_weights(
    args: Arguments,
    num_experts: int,
    rows: int,
    columns: int,
    init_method: InitFn,
):
    weights = create_moe_expert_weights(
        args,
        num_experts,
        rows,
        columns,
        init_method,
    )
    return weights.view([-1, columns])


class MemoryOptimizedMLP(torch.autograd.Function):
    """Sparse MLP with manually scheduled memory reuse."""

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx, x, w1, w2, topo, activation_fn):
        # Cast inputs using ctx dtype from AMP
        if ctx._fwd_used_autocast:
            x = x.to(ctx._dtype)
            w1 = w1.to(ctx._dtype)
            w2 = w2.to(ctx._dtype)
        # x: [m, k], w1: [n, k], w2: [n, k]
        if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
            raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")

        topo_tensors = (
            topo.row_indices,
            topo.column_indices,
            topo.offsets,
            topo.column_indices_t,
            topo.offsets_t,
            topo.block_offsets_t,
        )

        # Layer 0: x @ w1.t().
        sdd_out = stk.ops.sdd(x, w1.t(), topo)

        # GeLU.
        activation_fn_out = act_fn(sdd_out, activation_fn)

        # Layer 1: x @ w2.
        dsd_out = stk.ops.dsd(activation_fn_out, w2)

        # NOTE: Save the input to the layer and the activation_fn input for
        # gradient computation. We'll re-compute the activation_fn forward
        # pass in the backward pass to avoid materializing another
        # intermediate.
        ctx.shape = topo.shape
        ctx.x_shape = x.shape
        ctx.sdd_out_shape = sdd_out.data.shape
        ctx.dtype = x.dtype
        ctx.activation_fn = activation_fn
        ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
        return dsd_out

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx, ddsd_out):
        if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
            raise ValueError('Expected all MLP inputs to need grad.')

        # unpack saved tensors
        # dtype = ctx.dtype
        saved_tensors = ctx.saved_tensors
        w1, w2 = saved_tensors[:2]
        topo_tensors = saved_tensors[2:8]
        x = saved_tensors[8]
        sdd_out_data = saved_tensors[9]

        # rematerialize activation function output
        activation_fn = ctx.activation_fn
        sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
        activation_fn_out, activation_grad_fn = act_fn(
            sdd_out,
            activation_fn,
            return_grad_fn=True,
        )

        # Compute dw2 with recomputed activation_fn output.
        dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)

        # Compute dactivation_fn_out.
        #
        # NOTE: We reuse the activation_fn_out allocation.
        dactivation_fn_out = activation_fn_out
        stk.backend.triton_kernels.sdd(
            ddsd_out,
            w2.t(),
            dactivation_fn_out.shape,
            dactivation_fn_out.data,
            dactivation_fn_out.offsets,
            dactivation_fn_out.row_indices,
            dactivation_fn_out.column_indices,
        )

        # Compute dsdd_out.
        #
        # NOTE: This reuses the dactivation_fn_out allocation.
        if activation_fn is DEFAULT_ACTIVATION_FN:
            dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
        else:
            assert activation_grad_fn is not None
            activation_grad_fn(dactivation_fn_out.data)
            dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)

        # Compute dw1.
        dw1 = stk.ops.dsd(dsdd_out.t(), x)

        # Compute dx.
        #
        # NOTE: This reuses the ddsd_out allocation.
        stk.backend.triton_kernels.dsd(
            dsdd_out.shape,
            dsdd_out.data,
            dsdd_out.offsets,
            dsdd_out.row_indices,
            dsdd_out.column_indices,
            dsdd_out.offsets_t,
            dsdd_out.column_indices_t,
            dsdd_out.block_offsets_t,
            False,
            w1,
            ddsd_out,
        )
        dx = ddsd_out
        return dx, dw1, dw2, None, None


memory_optimized_mlp = MemoryOptimizedMLP.apply


class SparseMLP(torch.nn.Module):

    def __init__(self, args: Arguments):
        super().__init__()
        self.args = args
        self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)

        self.w1 = torch.nn.Parameter(
            torch.empty(
                self._num_rows_per_rank,
                args.hidden_size,
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
        self.w2 = torch.nn.Parameter(
            torch.empty(
                self._num_rows_per_rank,
                args.hidden_size,
                device=args.device,
                dtype=common.dtype(args),
            ),
        )

        # Initialize the parameters for the MLP.
        #
        # NOTE: It is important that we create the weight tensors prior
        # to creating the master weights and slicing out the piece for
        # this rank. If the master weights are created first the PyTorch
        # caching allocator appears to use the same memory block for these
        # and the slice which causes large increases in our peak memory
        # usage.
        with torch.no_grad():
            self.w1.copy_(
                create_dmoe_expert_weights(
                    args,
                    args.moe_num_experts,
                    args.ffn_hidden_size,
                    args.hidden_size,
                    args.init_method,
                ),
            )
            self.w2.copy_(
                create_dmoe_expert_weights(
                    args,
                    args.moe_num_experts,
                    args.ffn_hidden_size,
                    args.hidden_size,
                    args.output_layer_init_method,
                ),
            )

        self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
        mpu.set_expert_model_parallel_attributes(
            self.w1,
            self._should_set_parallelism_attribute,
        )
        mpu.set_expert_model_parallel_attributes(
            self.w2,
            self._should_set_parallelism_attribute,
        )

        self.gradient_scale = None
        if self.args.moe_expert_model_parallelism:
            self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)

    def scale_grad(self, w):
        if self.gradient_scale is None:
            return w
        return scale_gradient(w, self.gradient_scale)

    def forward(self, x, topo):
        w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
        w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
        if self.args.memory_optimized_mlp:
            return memory_optimized_mlp(
                x,
                w1,
                w2,
                topo,
                self.args.activation_fn,
            )

        # Compute the MLP.
        x = stk.ops.sdd(x, w1.t(), topo)
        activation_fn_out = act_fn(x, self.args.activation_fn)
        return stk.ops.dsd(activation_fn_out, w2)


class MemoryOptimizedGroupedMLP(torch.autograd.Function):
    """GroupedMLP with manually scheduled memory reuse."""

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
        # Cast inputs using ctx dtype from AMP
        if ctx._fwd_used_autocast:
            x = x.to(ctx._dtype)
            w1 = w1.to(ctx._dtype)
            w2 = w2.to(ctx._dtype)
        # x: [m, k], w1: [n, k], w2: [n, k]
        if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
            raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")

        # Layer 0: x @ w1.t().
        assert gg.backend is not None
        sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)

        # activation_fn
        activation_fn_out = activation_fn(sdd_out)

        # Layer 1: x @ w2.
        dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)

        # NOTE: Save the input to the layer and the activation_fn input for
        # gradient computation. We'll re-compute the activation_fn forward
        # pass in the backward pass to avoid materializing another
        # intermediate.
        ctx.x_shape = x.shape
        ctx.sdd_out_shape = sdd_out.shape
        ctx.dtype = x.dtype
        ctx.activation_fn = activation_fn
        ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
        return dsd_out

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx: Any, ddsd_out: torch.Tensor):
        if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
            raise ValueError('Expected all MLP inputs to need grad.')

        # Unpack saved tensors
        # dtype = ctx.dtype
        saved_tensors = ctx.saved_tensors
        w1, w2 = saved_tensors[:2]
        batch_sizes = saved_tensors[2]
        x = saved_tensors[3]
        sdd_out = saved_tensors[4]

        # Rematerialize activation_fn output.
        activation_fn = ctx.activation_fn
        with torch.set_grad_enabled(True):
            sdd_out.requires_grad = True
            activation_fn_out = activation_fn(sdd_out)
            activation_grad_fn = activation_fn_out.backward

        # Compute dw2 with recomputed activation_fn output.
        assert gg.backend is not None
        dw2 = gg.backend.gmm(
            activation_fn_out,
            ddsd_out,
            batch_sizes,
            trans_a=True,
        )

        # Compute dactivation_fn_out.
        #
        # NOTE: We reuse the activation_fn_out allocation.
        dactivation_fn_out = activation_fn_out
        gg.backend.gmm(
            ddsd_out,
            w2,
            batch_sizes,
            trans_b=True,
            c=dactivation_fn_out,
        )

        # Compute dsdd_out.
        #
        # NOTE: This reuses the dactivation_fn_out allocation.
        if activation_fn is DEFAULT_ACTIVATION_FN:
            dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
        else:
            assert activation_grad_fn is not None
            activation_grad_fn(dactivation_fn_out)
            dsdd_out = sdd_out.grad

        # Compute dw1.
        dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)

        # Compute dx.
        #
        # NOTE: This reuses the ddsd_out allocation.
        gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
        dx = ddsd_out
        return dx, dw1, dw2, None, None


memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply


class GroupedMLP(SparseMLP):

    def forward(self, x, tokens_per_expert):
        batch_sizes = tokens_per_expert.cpu().to(torch.long)
        w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))

        # Re-shape the weights for the grouped GEMMs.
        ne = mpu.experts_per_rank(self.args)
        w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
        w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)

        if self.args.memory_optimized_mlp:
            return memory_optimized_grouped_mlp(
                x,
                w1,
                w2,
                batch_sizes,
                self.args.activation_fn,
            )

        # Compute the MLP.
        assert gg.ops is not None
        x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
        x = self.args.activation_fn(x)
        return gg.ops.gmm(x, w2, batch_sizes)


class SharedMLP(torch.nn.Module):
    """MLP for shared expert.

    Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
    """

    def __init__(self, args: Arguments):
        super().__init__()
        self.args = args
        self.fc_kwargs: dict[str, Any] = {
            'bias': args.bias,
            'device': args.device,
        }
        self.fc_kwargs.update(args.fc_kwargs)

        self.up_proj = args.fc_cls(
            args.hidden_size,
            args.shared_expert_hidden_size,
            **self.fc_kwargs,
        )
        self.act = args.activation_fn
        self.down_proj = args.fc_cls(
            args.shared_expert_hidden_size,
            args.hidden_size,
            **self.fc_kwargs,
        )
        self.down_proj._is_residual = True  # a flag for llm-foundry init

    def add_experts_sharedexpert(
        self,
        shared_expert_out: torch.Tensor,
        expert_out: torch.Tensor,
    ) -> torch.Tensor:
        # Helper function to add expert output to shared expert output
        # with optional weighted sum.
        if self.args.shared_expert_weighted_sum:
            # enable using weighted sum for shared expert output
            # weighted by number of experts used
            t_experts = self.args.moe_top_k + 1
            sh_mlp_out = shared_expert_out / t_experts
            return sh_mlp_out.add(
                expert_out,
                alpha=(self.args.moe_top_k / t_experts),
            )

        return shared_expert_out + expert_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.up_proj(x)))
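The ScaleGradient helper above is an identity in the forward pass that rescales the incoming gradient in the backward pass; here it divides expert gradients by the expert-parallel world size. Below is a standalone sketch of the same pattern with an illustrative scale of 0.25 (as if the world size were 4); it omits the AMP decorators and is not the library's implementation.

import torch

class _ScaleGrad(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, scale):
        # Identity forward; remember the scale for the backward pass.
        ctx.scale = scale
        return x

    @staticmethod
    def backward(ctx, grad):
        # Rescale the gradient flowing back to the parameter.
        return grad * ctx.scale, None

w = torch.randn(3, requires_grad=True)
y = _ScaleGrad.apply(w, 0.25).sum()
y.backward()
print(w.grad)  # 0.25 everywhere: the gradient is scaled before the optimizer sees it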
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py
ADDED
@@ -0,0 +1,475 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist

import megablocks.ops as ops
from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
from megablocks.layers.all_to_all import all_to_all
from megablocks.layers.arguments import Arguments

_LOAD_BALANCING_LOSS = []


def save_load_balancing_loss(loss):
    global _LOAD_BALANCING_LOSS
    _LOAD_BALANCING_LOSS.append(loss)


def get_load_balancing_loss():
    global _LOAD_BALANCING_LOSS
    return _LOAD_BALANCING_LOSS


def clear_load_balancing_loss():
    global _LOAD_BALANCING_LOSS
    _LOAD_BALANCING_LOSS.clear()


def batched_load_balancing_loss(args: Arguments):
    if args.moe_loss_weight == 0:
        return 0.0

    # tokens_per_expert[i].shape = (num_experts)
    # expert_scores[i].shape = (tokens, num_experts)
    tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
    num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
    if args.num_layers_per_virtual_pipeline_stage is not None:
        num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage

    if len(tokens_per_expert) != num_layers_per_pipeline_stage:
        raise ValueError(
            f'Expected {num_layers_per_pipeline_stage} token_per_experts '
            f'but found {len(tokens_per_expert)}.\nnum_layers = '
            f'{args.num_layers}\npipeline_model_parallel_size = '
            f'{args.pipeline_model_parallel_size}\n'
            'num_layers_per_virtual_pipeline_stage'
            f' = {args.num_layers_per_virtual_pipeline_stage}',
        )
    if len(expert_scores) != num_layers_per_pipeline_stage:
        raise ValueError(
            f'Expected {num_layers_per_pipeline_stage} expert_scores '
            f'but found {len(tokens_per_expert)}.\nnum_layers = '
            f'{args.num_layers}\npipeline_model_parallel_size = '
            f'{args.pipeline_model_parallel_size}\n'
            'num_layers_per_virtual_pipeline_stage'
            f' = {args.num_layers_per_virtual_pipeline_stage}',
        )

    # Verify the shape of the tokens_per_expert and expert_scores tensors.
    assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))

    tokens = expert_scores[0].shape[0]
    assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))

    # Concatenate the contributions of each layer and convert to
    # the correct types and formats for the dot product.
    expert_scores = torch.cat(expert_scores, dim=1)
    if args.moe_lbl_in_fp32:
        expert_scores = expert_scores.float()
    if tokens != 0:
        expert_scores = expert_scores.mean(dim=0)
    else:
        expert_scores = expert_scores.sum(dim=0)
    tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)

    expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
    assert tokens_per_expert.numel() == expected_values
    assert expert_scores.numel() == expected_values

    # Calculate the total scale across all factors.
    #
    # loss_weight * num_experts / (num_layers * tokens * top_k)
    scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
    scale_denominator = (args.num_layers * tokens * args.moe_top_k)
    scale = scale_numerator / scale_denominator
    return scale * torch.dot(tokens_per_expert, expert_scores)


# NOTE: This class defines MoE expert computation, including expert model parallel
# communication. When using FSDP on top of MegaBlocks this is the module that should
# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
# parallel all2all.
class ParallelMLP(torch.nn.Module):

    def __init__(self, args: Arguments):
        super(ParallelMLP, self).__init__()
        self.args = args

        # Calculate the number of experts in total and the number of experts
        # owned by this rank.
        # world_size = mpu.get_expert_parallel_world_size(args)
        self.num_experts = args.moe_num_experts
        self.top_k = self.args.moe_top_k

        # Calculate the number of bits needed to represent the expert indices
        # so that we can pass it to radix sort.
        self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)

        # Expert MLP.
        self.mlp = mlp.MLP(args)

        self.bias: Optional[torch.Tensor]
        if self.args.bias:
            # Note that the output bias is not parallelized with expert
            # model parallelism.
            self.bias = torch.nn.Parameter(
                torch.empty(
                    args.hidden_size,
                    device=args.device,
                    dtype=common.dtype(args),
                ),
            )
            torch.nn.init.zeros_(self.bias)
        else:
            self.register_parameter('bias', None)

        # Select the forward function for the operating mode.
        self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)

    def expert_capacity(self, tokens: int) -> int:
        world_size = mpu.get_expert_parallel_world_size(self.args)
        tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
        return int(self.args.moe_capacity_factor * tokens_per_expert)

    def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
        """Calculate the load balancing loss contribution."""
        assert len(expert_scores.size()) == 2
        tokens, num_experts = expert_scores.size()
        assert num_experts == self.num_experts
        assert len(tokens_per_expert.size()) == 1
        num_experts, = tokens_per_expert.size()
        assert num_experts == self.num_experts
        scale = self.num_experts / (tokens * self.top_k)
        return scale * torch.dot(
            tokens_per_expert.to(expert_scores.dtype),
            expert_scores.mean(dim=0),
        )

    def indices_and_bins(self,
                         top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Sort the expert ids to produce the scatter/gather
        # indices for the permutation.
        #
        # TODO(tgale): Is it worth doing this conversion to 32-bit
        # prior? Could we place the `torch.max` operation to return
        # 32-bit expert indices?
        top_expert = top_expert.int()
        output = ops.sort(top_expert, self.sort_end_bit)
        assert output is not None
        bin_ids, indices = output

        # Histogram the expert ids to identify the number of
        # tokens routed to each expert.
        #
        # TODO(tgale): Does the sorted data produce a more favorable
        # data distribution for histogram? Or is the op parallelism
        # worth more?
        tokens_per_expert = ops.histogram(top_expert, self.num_experts)

        # Calculate the bin bounds for the sorted tokens.
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
        assert bins is not None
        bins = bins.view(1) if not len(bins.size()) else bins

        assert isinstance(indices, torch.Tensor)
        assert isinstance(bin_ids, torch.Tensor)
        assert isinstance(bins, torch.Tensor)
        assert isinstance(tokens_per_expert, torch.Tensor)

        return indices, bin_ids, bins, tokens_per_expert

    def permute_and_compute(
        self,
        x: torch.Tensor,
        tokens_per_expert: int,  # unused
        indices: torch.Tensor,
        bin_ids: torch.Tensor,  # unused
        expert_weights: torch.Tensor,
        bins: torch.Tensor,
        expert_capacity: int,
        top_k: int,
    ):
        # Route the tokens for MoE computation.
        x = x.view(-1, x.shape[-1])
        output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
        assert output is not None
        x = output

        # Perform the expert computation. Note that we don't
        # use biases for these linear operations.
        x = self.mlp(x)

        # Un-route the data for the MoE output.
        return ops.binned_scatter(x, indices, expert_weights, bins, top_k)

    def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
        # x: [sl, bs, hs]
        # expert_weights: [sl * bs, top-k]
        # top_experts: [sl * bs, top-k]
        expert_weights = expert_weights.flatten()
        top_experts = top_experts.flatten()
        with torch.no_grad():
            indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))

            # If expert_capacity is set to zero, set the number of tokens
            # per expert to the maximum we need to avoid dropping tokens.
            sl, bs, _ = x.size()
            expert_capacity = self.expert_capacity(sl * bs)
            if expert_capacity == 0:
                expert_capacity = torch.max(tokens_per_expert).item()

        x = self.permute_and_compute(
            x,
            tokens_per_expert,
            indices,
            bin_ids,
            expert_weights,
            bins,
            expert_capacity,
            self.top_k,
        )
        return x, tokens_per_expert

    def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
        # NOTE: This function implements the same computation as forward_once
        # but with expert model parallelism.
        #
        # 1. Permute the tokens locally so that they are grouped by their
        # expert assignments. This allows us to transfer all of the tokens
        # for a remote device in one communication primitive.
        #
        # 2. Permute the tokens across the expert parallel devices. After
        # this is completed each device has all of the tokens assigned to
        # its set of experts in its local HBM.
        #
        # 3. Permute the tokens locally so that they are grouped by their
        # expert assignment. After the distributed permutation the tokens
        # are grouped by which device they came from. We re-order them
        # locally to allow for efficient computation.
        #
        # After this series of permutations we compute the linear layers
        # and then repeat these three steps in reverse to produce the final
        # output.
        #
        # Compute the mapping of local tokens to experts.
        expert_weights = expert_weights.flatten()
        top_experts = top_experts.flatten()
        with torch.no_grad():
            indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))

            # If we're sharding the experts along the hidden dimension
            # multiple devices own parts of the same sets of experts.
            # Replicate the token counts so every device gets the counts.
            repeated_tokens_per_expert = ops.repeat(
                tokens_per_expert,
                (mpu.hidden_sharding_degree(self.args),),
            )

            # Pass token count information to the device on which the
            # target expert resides.
            parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
            tpe_handle = dist.all_to_all_single(
                parallel_tokens_per_expert,
                repeated_tokens_per_expert,
                group=self.args.expert_parallel_group,
                async_op=True,
            )

        # Permute locally and without any padding so that tokens for each
        # parallel device are stored contiguously.
        #
        # This view updates the shape of the tensor from [sl, bs, hs] to
        # [sl * bs, hs] prior to the permutation.
        x = x.view(-1, x.shape[-1])
        output = ops.gather(x, indices, bin_ids, bins, self.top_k)
        assert output is not None
        x = output

        # Compute the number of tokens that will be received from each
        # device and permute the input data across the devices.
        with torch.no_grad():
            tpe_handle.wait()
            experts_per_rank = mpu.experts_per_rank(self.args)

            # Reshape to [world_size, num_experts_per_rank].
            world_size = mpu.get_expert_parallel_world_size(self.args)
            repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
            parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))

            # TODO(tgale): It might be faster to do this on the GPU and
            # then communicate the results back to the host.
            send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
            parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
            recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)

            # Convert the send/recv counts to lists.
            send_counts = send_counts.tolist()
            recv_counts = recv_counts.tolist()
            tokens_received = sum(recv_counts)

        # If we're sharding the experts along the hidden dimension
        # multiple devices own parts of the same sets of experts.
        # Replicate the token counts so devices that share experts
        # get all of the tokens assigned to them.
        #
        # TODO(tgale): Fuse this into the prior, local permutation.
        x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))

        # Start the cross-device permutation asynchronously so we can
        # overlap communication with computation.
        parallel_x, parallel_x_handle = all_to_all(
            x,
            recv_counts,
            send_counts,
            self.args.expert_parallel_group,
            async_op=True,
        )

        with torch.no_grad():
            # After we do the cross-device permutation we have the tokens on the
            # correct device but not yet grouped by expert because we received
            # tokens from each device as contiguous chunks. To group the tokens
            # for expert computation we'll do one more local permutation. The
            # rest of this torch.no_grad() scope sets up the indices and bins
            # for this permutation.
            replicate_bins = ops.inclusive_cumsum(
                parallel_tokens_per_expert.flatten(),
                0,
            )
            replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)

            # Construct the expert indices for the permuted tokens.
            parallel_top_expert = torch.remainder(
                torch.arange(
                    self.num_experts * mpu.hidden_sharding_degree(self.args),
                    dtype=torch.int32,
                    device=indices.device,
                ),
                mpu.experts_per_rank(self.args),
            )
            parallel_top_expert = ops.replicate(
                parallel_top_expert.unsqueeze(dim=0),
                replicate_bins,
                tokens_received,
            ).flatten()

            # TODO(tgale): The sort_end_bit here can be reduced.
            parallel_bin_ids, parallel_indices = ops.sort(
                parallel_top_expert,
                self.sort_end_bit,
            )

            # Calculate the bin boundaries from the token counts.
            parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
                dim=0,
                dtype=torch.int,
            )
            parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
            parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)

            # If expert_capacity is set to zero, set the number of tokens
            # per expert to the maximum we need to avoid dropping tokens.
            tokens, _ = x.size()
            expert_capacity = self.expert_capacity(tokens)
            if expert_capacity == 0:
                expert_capacity = torch.max(parallel_tokens_per_expert).item()

        # Locally permute the tokens and perform the expert computation.
        # Block to make sure that the cross-device permutation is complete.
        if self.args.mlp_impl == 'grouped':
            # GroupedMLP requires counts on CPU. We can use the tensor already
            # moved to CPU for the prior all_to_all, which avoids an extra
            # device synchronization.
            parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
                dim=0,
                dtype=torch.int,
            )
        parallel_x_handle.wait()
        parallel_x = self.permute_and_compute(
            parallel_x,
            parallel_tokens_per_expert,
            parallel_indices,
            parallel_bin_ids,
            None,  # expert_weights
            parallel_bins,
            expert_capacity,
            top_k=1,
        )

        # Un-permute the tokens across the devices.
        x, _ = all_to_all(
            parallel_x,
            send_counts,
            recv_counts,
            self.args.expert_parallel_group,
|
409 |
+
)
|
410 |
+
|
411 |
+
# Reduce along the hidden sharding to get the final outputs.
|
412 |
+
#
|
413 |
+
# TODO(tgale): Fuse this into the following local permutation.
|
414 |
+
shape = (
|
415 |
+
mpu.hidden_sharding_degree(self.args),
|
416 |
+
-1,
|
417 |
+
self.args.hidden_size,
|
418 |
+
)
|
419 |
+
x = ops.sum(x.view(shape), dim=0)
|
420 |
+
|
421 |
+
# Un-permute locally to setup for the next series of operations.
|
422 |
+
x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
|
423 |
+
return x, tokens_per_expert.flatten()
|
424 |
+
|
425 |
+
def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
|
426 |
+
in_shape = x.size()
|
427 |
+
|
428 |
+
# Compute the experts.
|
429 |
+
x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
|
430 |
+
if self.training and self.args.moe_loss_weight > 0:
|
431 |
+
save_load_balancing_loss((tokens_per_expert, scores))
|
432 |
+
x = x.view(in_shape)
|
433 |
+
if self.bias is not None:
|
434 |
+
if self.args.return_bias:
|
435 |
+
return x, self.bias
|
436 |
+
return x + self.bias
|
437 |
+
return x
|
438 |
+
|
439 |
+
|
440 |
+
class MoE(torch.nn.Module):
|
441 |
+
|
442 |
+
def __init__(self, args: Arguments):
|
443 |
+
super(MoE, self).__init__()
|
444 |
+
|
445 |
+
# Token router.
|
446 |
+
self.router = router.LearnedRouter(args)
|
447 |
+
|
448 |
+
# Expert computation helper.
|
449 |
+
self.experts = self._init_experts_mlp(args)
|
450 |
+
|
451 |
+
self.shared_expert = None
|
452 |
+
if args.shared_expert:
|
453 |
+
# SharedExpert computation helper.
|
454 |
+
self.shared_expert = sharedexpert_registry.get(args)
|
455 |
+
|
456 |
+
def _init_experts_mlp(self, args: Arguments):
|
457 |
+
return ParallelMLP(args)
|
458 |
+
|
459 |
+
def forward(self, x: torch.Tensor):
|
460 |
+
# NOTE: If we're going to cast the activations to lower precision
|
461 |
+
# do it before we permute the tokens to save bandwidth.
|
462 |
+
x = common.cast_if_autocast_enabled(x)
|
463 |
+
|
464 |
+
# Compute the expert scores and assignments.
|
465 |
+
scores, expert_weights, top_experts = self.router(x)
|
466 |
+
|
467 |
+
# Compute the experts.
|
468 |
+
out = self.experts(x, scores, expert_weights, top_experts)
|
469 |
+
if self.shared_expert is not None:
|
470 |
+
shared_expert_out = self.shared_expert(x)
|
471 |
+
out = self.shared_expert.add_experts_sharedexpert(
|
472 |
+
shared_expert_out,
|
473 |
+
out,
|
474 |
+
)
|
475 |
+
return out
|
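
The class above ties the router, the expert MLPs, and the optional shared expert together. As a minimal usage sketch, assuming a single device with no expert parallelism, and assuming the Arguments dataclass accepts the field names that the code above reads from it (hidden_size, ffn_hidden_size, moe_num_experts, moe_top_k); the concrete values here are illustrative only:

    # Sketch only: drives MoE end to end on one device; not the canonical setup.
    import torch
    from megablocks.layers.arguments import Arguments
    from megablocks.layers.moe import MoE

    args = Arguments(
        hidden_size=1024,
        ffn_hidden_size=4096,
        moe_num_experts=8,
        moe_top_k=2,
    )
    layer = MoE(args)
    x = torch.randn(16, 4, args.hidden_size)  # [sequence, batch, hidden]
    out = layer(x)  # returns (x, bias) when return_bias is set, else x
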
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py
ADDED
@@ -0,0 +1,93 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
import torch.distributed as dist

from megablocks.layers.arguments import Arguments


class MoeParam(torch.Tensor):

    def __init__(self):
        super().__init__(self)
        self.expert_model_parallel: bool


def is_moe_param(tensor: torch.Tensor) -> bool:
    return hasattr(tensor, 'expert_model_parallel')


def get_expert_parallel_world_size(args: Arguments) -> int:
    return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)


def get_expert_parallel_rank(args: Arguments) -> int:
    return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)


def set_expert_model_parallel_attributes(
    tensor: torch.Tensor,
    is_parallel: bool,
):
    assert not hasattr(tensor, 'expert_model_parallel')
    setattr(tensor, 'expert_model_parallel', is_parallel)


def param_is_expert_model_parallel(param: MoeParam) -> bool:
    return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)


def copy_expert_model_parallel_attributes(
    destination_tensor: torch.Tensor,
    source_tensor: torch.Tensor,
):
    if hasattr(source_tensor, 'expert_model_parallel'):
        setattr(
            destination_tensor,
            'expert_model_parallel',
            getattr(source_tensor, 'expert_model_parallel'),
        )


def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    for i in range(world_size):
        dist.barrier(group)
        if i == rank:
            print(f'rank = {rank}', *x)


# Helpers for expert/tensor sharding.
def expert_sharding_degree(args: Arguments) -> int:
    world_size = get_expert_parallel_world_size(args)
    esd = min(world_size, args.moe_num_experts)

    if (args.moe_num_experts % esd) != 0:
        raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
    return esd


def hidden_sharding_degree(args: Arguments) -> int:
    world_size = get_expert_parallel_world_size(args)
    esd = expert_sharding_degree(args)
    hsd = world_size // esd

    if (args.ffn_hidden_size % hsd) != 0:
        raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
    if (esd * hsd) != world_size:
        raise ValueError(
            f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
        )
    return hsd


def experts_per_rank(args: Arguments) -> int:
    return args.moe_num_experts // expert_sharding_degree(args)


def features_per_rank(args: Arguments) -> int:
    return args.ffn_hidden_size // hidden_sharding_degree(args)
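
The sharding helpers above split the expert-parallel world size between expert sharding and hidden sharding. A small worked example of that arithmetic (plain Python, no process group needed), assuming expert model parallelism is enabled:

    # Worked example of the mpu.py sharding arithmetic.
    # With 8 expert-parallel ranks and 4 experts, at most 4 ranks can own
    # distinct experts, so the remaining factor of 2 shards the FFN hidden dim.
    world_size = 8
    moe_num_experts = 4
    ffn_hidden_size = 4096

    expert_sharding_degree = min(world_size, moe_num_experts)        # 4
    hidden_sharding_degree = world_size // expert_sharding_degree    # 2
    experts_per_rank = moe_num_experts // expert_sharding_degree     # 1
    features_per_rank = ffn_hidden_size // hidden_sharding_degree    # 2048

    assert expert_sharding_degree * hidden_sharding_degree == world_size
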
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py
ADDED
@@ -0,0 +1,114 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch

from megablocks.layers import common
from megablocks.layers.arguments import Arguments

_ROUTER_LOGITS = []


def _save_router_logits(logits: torch.Tensor, args: Arguments):
    if args.moe_zloss_weight == 0:
        return
    global _ROUTER_LOGITS
    _ROUTER_LOGITS.append(logits)


def clear_router_zloss():
    global _ROUTER_LOGITS
    _ROUTER_LOGITS.clear()


def batched_router_zloss(args: Arguments):
    global _ROUTER_LOGITS

    if args.moe_zloss_weight == 0:
        import warnings
        warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
        return 0

    logits_per_router = _ROUTER_LOGITS

    if args.moe_zloss_in_fp32:
        logits_per_router = [logits.float() for logits in logits_per_router]

    unscaled_zloss_per_router = torch.stack([
        torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
    ])

    return args.moe_zloss_weight * unscaled_zloss_per_router


# NOTE: To enable end-to-end benchmarking without convergence we
# support a flag to force the router to assign tokens uniformly
# across the experts. We do this with a custom autograd operation
# so that PyTorch still executes the full set of router operations.
class _UniformExpertAssignment(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, num_experts: int):
        out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
        out = torch.remainder(out, num_experts)
        return out.view(x.shape)


_uniform_expert_assignment = _UniformExpertAssignment.apply


class LearnedRouter(torch.nn.Module):

    def __init__(self, args: Arguments):
        super().__init__()
        self.args = args

        # Learned router parameters.
        #
        # NOTE: This weight matrix is not parallelized with expert model
        # parallelism. Each device needs the entire router weight matrix
        # so that it can route its batch of data correctly.
        self.layer = torch.nn.Linear(
            args.hidden_size,
            args.moe_num_experts,
            bias=False,
            dtype=common.dtype(args),
            device=args.device,
        )
        args.init_method(self.layer.weight)

    def jitter(self, x: torch.Tensor):
        low: float = 1.0 - self.args.moe_jitter_eps
        high: float = 1.0 + self.args.moe_jitter_eps
        noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
        return low + noise * (high - low)

    def _top_k(self, scores: torch.Tensor):
        if self.args.moe_top_k == 1:
            return scores.max(dim=-1, keepdim=True)
        return torch.topk(scores, self.args.moe_top_k, dim=-1)

    def forward(self, x: torch.Tensor):
        if self.training and self.args.moe_jitter_eps is not None:
            x = x * self.jitter(x)

        logits = self.layer(x.view(-1, x.shape[-1]))
        _save_router_logits(logits, self.args)
        scores = logits.softmax(dim=-1)
        expert_weights, expert_indices = self._top_k(scores)
        if self.args.moe_normalize_expert_weights:
            expert_weights = expert_weights / torch.norm(
                expert_weights,
                p=self.args.moe_normalize_expert_weights,
                dim=-1,
                keepdim=True,
            )

        expert_indices = (
            _uniform_expert_assignment(
                expert_indices,
                self.args.moe_num_experts,
            ) if self.args.uniform_expert_assignment else expert_indices
        )
        return scores, expert_weights, expert_indices
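
For reference, the routing math in LearnedRouter.forward reduces to a linear projection, a softmax over per-expert logits, and a top-k selection. A standalone plain-torch sketch with made-up sizes (the optional p-norm renormalization mirrors moe_normalize_expert_weights):

    # Plain-torch sketch of the routing computation; no megablocks imports.
    import torch

    num_tokens, hidden_size, num_experts, top_k = 32, 1024, 8, 2
    x = torch.randn(num_tokens, hidden_size)
    router_weight = torch.randn(num_experts, hidden_size)

    logits = x @ router_weight.t()                 # [tokens, experts]
    scores = logits.softmax(dim=-1)
    expert_weights, expert_indices = torch.topk(scores, top_k, dim=-1)

    # Optional renormalization, as done when moe_normalize_expert_weights is set.
    expert_weights = expert_weights / torch.norm(expert_weights, p=1, dim=-1, keepdim=True)
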
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py
ADDED
@@ -0,0 +1,30 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Union

from megablocks.layers import glu, mlp
from megablocks.layers.arguments import Arguments

_REGISTRY = {
    'mlp': mlp.SharedMLP,
    'glu': glu.SharedGLU,
}


def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
    """Returns a SharedMLP for use in a dMoE instance.

    Uses the provided arguments to instantiate the appropriate
    SharedMLP instance.

    Args:
        args: propagated Arguments dataclass.

    Returns:
        An instantiated SharedMLP constructed using the input args.
    """
    if args.mlp_type not in _REGISTRY:
        raise ValueError(f'Unsupported mlp type: {args.mlp_type}')

    return _REGISTRY[args.mlp_type](args)
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py
ADDED
@@ -0,0 +1,35 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from megablocks.ops.binned_gather import binned_gather
from megablocks.ops.binned_scatter import binned_scatter
from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum
from megablocks.ops.gather import gather
from megablocks.ops.histogram import histogram
from megablocks.ops.padded_gather import padded_gather
from megablocks.ops.padded_scatter import padded_scatter
from megablocks.ops.repeat import repeat
from megablocks.ops.replicate import replicate
from megablocks.ops.round_up import round_up
from megablocks.ops.scatter import scatter
from megablocks.ops.sort import sort
from megablocks.ops.sum import sum
from megablocks.ops.topology import topology

__all__ = [
    'binned_gather',
    'binned_scatter',
    'exclusive_cumsum',
    'inclusive_cumsum',
    'gather',
    'histogram',
    'padded_gather',
    'padded_scatter',
    'repeat',
    'replicate',
    'round_up',
    'scatter',
    'sort',
    'sum',
    'topology',
]
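
These ops compose into the token-permutation pipeline used throughout this commit (the same sequence appears in the benchmark files below): sort tokens by expert, histogram the assignments, round counts up to the block size, and turn counts into bin boundaries with a cumulative sum. A short sketch of that sequence on the GPU, mirroring the benchmark setup code:

    # Sketch of the bin-construction pipeline built from these ops.
    import torch
    from megablocks import ops

    num_tokens, num_experts = 16384, 8
    top_expert = torch.randint(0, num_experts, (num_tokens,)).cuda().int()

    bin_ids, indices = ops.sort(top_expert)                     # tokens ordered by expert
    tokens_per_expert = ops.histogram(top_expert, num_experts)  # counts per expert
    padded = ops.round_up(tokens_per_expert, 128)               # pad to the block size
    padded_bins = ops.inclusive_cumsum(padded, 0)               # padded bin boundaries
    bins = ops.inclusive_cumsum(tokens_per_expert, 0)           # unpadded bin boundaries
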
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py
ADDED
@@ -0,0 +1,60 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.distributed as dist

from megablocks import benchmark_util
from megablocks.layers.all_to_all import all_to_all

_ALL_TO_ALL_BENCHMARK = (
    (8, 1024),
    (16, 1024),
    (32, 1024),
    (64, 1024),
    (128, 1024),
    (256, 1024),
    (512, 1024),
    (1024, 1024),
    (2 * 1024, 1024),
    (4 * 1024, 1024),
    (8 * 1024, 1024),
    (16 * 1024, 1024),
    (32 * 1024, 1024),
    (64 * 1024, 1024),
    (128 * 1024, 1024),
    (256 * 1024, 1024),
    (512 * 1024, 1024),
    (1024 * 1024, 1024),
)


def benchmark_all_to_all(group, sl, hs):
    world_size = dist.get_world_size(group)
    assert (sl % world_size) == 0
    send_recv_sizes = [sl // world_size] * world_size

    x = torch.randn((sl, hs)).cuda().half()

    details = {
        'world_size': world_size,
        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2B elements.
    }

    def benchmark():
        return all_to_all(x, send_recv_sizes, send_recv_sizes, group)

    time, std = benchmark_util.benchmark_function(benchmark)

    if dist.get_rank(group) == 0:
        benchmark_util.log_benchmark('All-To-All', details, time, std)


if __name__ == '__main__':
    assert dist.is_available()
    group = dist.init_process_group(backend='nccl')
    local_rank = dist.get_rank(group)
    torch.cuda.set_device(local_rank)

    for args in _ALL_TO_ALL_BENCHMARK:
        benchmark_all_to_all(group, *args)
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py
ADDED
@@ -0,0 +1,37 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch
from stk.backend.autocast import custom_bwd, custom_fwd

from megablocks.backend import kernels


# Autograd wrapper for binned_gather kernel.
class BinnedGatherOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx: Any,
        x: torch.Tensor,
        indices: torch.Tensor,
        bins: torch.Tensor,
        bin_size: int,
        top_k: int,
    ):
        ctx.save_for_backward(indices, bins)
        ctx.top_k = top_k
        return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        grad = grad.contiguous()
        indices, bins = ctx.saved_tensors
        out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
        return out, None, None, None, None


binned_gather = BinnedGatherOp.apply
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py
ADDED
@@ -0,0 +1,59 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch
from stk.backend.autocast import custom_bwd, custom_fwd

from megablocks.backend import kernels


# Autograd wrapper for binned_scatter kernel.
class BinnedScatterOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx: Any,
        x: torch.Tensor,
        indices: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        top_k: int,
    ):
        assert len(x.size()) == 3
        ctx.bin_size = x.size(1)
        ctx.top_k = top_k

        # TODO(tgale): Don't save 'x' for backwards if we don't need to
        # calculate the gradient w.r.t. 'weights'.
        ctx.save_for_backward(x, indices, weights, bins)
        return kernels.binned_scatter(x, indices, weights, bins, top_k)

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        grad = grad.contiguous()
        x, indices, weights, bins = ctx.saved_tensors
        out = kernels.binned_gather(
            grad,
            indices,
            weights,
            bins,
            ctx.bin_size,
            ctx.top_k,
        )

        wgrad = None
        if ctx.needs_input_grad[2]:
            wgrad = kernels.binned_scatter_wgrad(
                x,
                grad,
                indices,
                bins,
                ctx.top_k,
            )
        return out, None, wgrad, None, None


binned_scatter = BinnedScatterOp.apply
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py
ADDED
@@ -0,0 +1,52 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Any

# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch

# Wrap this in a try-block with better error message and
# instructions for building the c++ operations.
try:
    # import megablocks_ops as ops  # type: ignore
    from megablocks._ops import ops  # type: ignore
except ModuleNotFoundError as e:
    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e


# Autograd wrappers for cumsum kernels.
# NOTE: Does not support gradients.
class ExclusiveCumsumOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, dim: int):
        if len(x.size()) == 1:
            x = x.view([1, -1])
            out = torch.empty_like(x)
            ops.exclusive_cumsum(x, 1, out)
            return out.squeeze()
        out = torch.empty_like(x)
        ops.exclusive_cumsum(x, dim, out)
        return out


exclusive_cumsum = ExclusiveCumsumOp.apply


class InclusiveCumsumOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
        if len(x.size()) == 1:
            x = x.view([1, -1])
            out = torch.empty_like(x)
            ops.inclusive_cumsum(x, 1, out)
            return out.squeeze()
        out = torch.empty_like(x)
        ops.inclusive_cumsum(x, dim, out)
        return out


inclusive_cumsum = InclusiveCumsumOp.apply
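
Both wrappers defer to the compiled extension; as these ops are used for bin boundaries elsewhere in this commit, their intended semantics correspond to a cumulative sum, with the exclusive variant shifted by one position. A CPU reference sketch of that relationship in plain torch:

    # Reference semantics for the cumsum wrappers (plain torch, CPU).
    import torch

    counts = torch.tensor([3, 0, 2, 5], dtype=torch.int32)

    inclusive = torch.cumsum(counts, dim=0)    # [3, 3, 5, 10] -> bin end offsets
    exclusive = inclusive - counts             # [0, 3, 3, 5]  -> bin start offsets
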
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py
ADDED
@@ -0,0 +1,38 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch
from stk.backend.autocast import custom_bwd, custom_fwd

from megablocks.backend import kernels


# Autograd wrapper for gather kernel.
class GatherOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx: Any,
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        bins: torch.Tensor,
        top_k: int,
    ):
        ctx.save_for_backward(indices, bin_ids, bins)
        ctx.top_k = top_k
        return kernels.gather(x, indices, bin_ids, None, bins, top_k)

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        grad = grad.contiguous()

        indices, bin_ids, bins = ctx.saved_tensors
        out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
        return out, None, None, None, None, None


gather = GatherOp.apply
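
gather groups the (possibly top_k-replicated) tokens by expert using the indices produced by ops.sort; the scatter kernel invoked in the backward pass is its inverse. As a mental model only, not the kernel itself, the top_k == 1 case behaves like an index permutation and its inverse in plain torch:

    # Mental-model sketch of gather/scatter for top_k == 1 (plain torch).
    # The real kernels avoid this materialization and handle top_k > 1.
    import torch

    num_tokens, hidden_size, num_experts = 8, 4, 2
    x = torch.randn(num_tokens, hidden_size)
    top_expert = torch.randint(0, num_experts, (num_tokens,))

    indices = torch.argsort(top_expert)      # token order grouped by expert
    grouped = x[indices]                     # "gather": rows grouped by expert

    ungrouped = torch.empty_like(grouped)    # "scatter": invert the permutation
    ungrouped[indices] = grouped
    assert torch.equal(ungrouped, x)
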
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Any

# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch

# Wrap this in a try-block with better error message and
# instructions for building the c++ operations.
try:
    from megablocks._ops import ops  # type: ignore
except ModuleNotFoundError as e:
    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e


# Autograd wrapper for histogram kernel.
# NOTE: Does not support gradients.
class HistogramOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, max_val: float):
        return ops.histogram(x, max_val)


histogram = HistogramOp.apply
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py
ADDED
@@ -0,0 +1,78 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import unittest

import numpy as np
import torch
from absl.testing import parameterized

from megablocks import ops

_HISTOGRAM_TESTS = (
    (16384, torch.int32, 2),
    (16384, torch.int32, 4),
    (16384, torch.int32, 8),
    (16384, torch.int32, 16),
    (16384, torch.int32, 32),
    (16384, torch.int32, 64),
    (16384, torch.int32, 128),
    (16384, torch.int32, 256),
)


def benchmark_function(fn, iterations=10):
    # Run once to get rid of startup overhead.
    fn()
    times = []
    for _ in range(iterations):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    times = np.array(times)
    return times.mean(), times.std(), times.max(), times.min()


def log_benchmark(arguments, mean_t, std_t):
    print('=' * 60)
    print('Benchmark Parameters:')
    for (key, value) in arguments.items():
        print(f'{key} = {value}')
    print('Results:')
    print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
    print('=' * 60)


class HistogramBenchmark(parameterized.TestCase):

    @parameterized.parameters(*_HISTOGRAM_TESTS)
    def testHistogram(self, n, dtype, max_val):
        x = torch.randint(0, max_val, (n,)).cuda().to(dtype)

        mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
        arguments = {
            'n': n,
            'dtype': dtype,
            'max_val': max_val,
        }
        log_benchmark(arguments, mean_t, std_t)

    @parameterized.parameters(*_HISTOGRAM_TESTS)
    def testTorchHistogram(self, n, dtype, max_val):
        x = torch.randint(0, 128, (n,)).cuda().to(dtype)

        mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
        arguments = {
            'n': n,
            'dtype': dtype,
            'max_val': max_val,
        }
        log_benchmark(arguments, mean_t, std_t)


if __name__ == '__main__':
    unittest.main()
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py
ADDED
@@ -0,0 +1,403 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import unittest

import stk
import torch
from absl.testing import parameterized

from megablocks import benchmark_util, ops


# Calling tensor.t() calls tensor.transpose(0, 1) which calls
# torch.as_strided(...). Circumvent this chain to avoid an overhead
# this adds.
def transpose_view(x):
    return torch.as_strided(
        x,
        (x.shape[1], x.shape[0]),
        (x.stride()[1], x.stride()[0]),
    )


_MATMUL_TESTS = (
    (64 * 1024, 512, 2048, 64),
    (32 * 1024, 768, 3072, 64),
    (8 * 1024, 1024, 4096, 64),
    (4 * 2048, 4096, 4 * 4096, 4),
)


def log_benchmark(name, arguments, time, std, flops):
    benchmark_util.log_benchmark(name, arguments, time, std)
    print('flops = {:.2f}B'.format(flops / 1e9))
    print('throughput = {:.2f}T'.format(flops / 1e9 / time))
    print('=' * 60)


class MatmulBenchmark(parameterized.TestCase):

    def build_sparse_matrix(self, x, padded_bins, fhs, ne):
        blocking = 128
        padded_tokens, _ = x.size()
        assert padded_tokens % blocking == 0
        assert fhs % blocking == 0

        # Offsets for the sparse matrix. All rows have the
        # same number of nonzero blocks dictated by the
        # dimensionality of a single expert.
        block_rows = padded_tokens // blocking
        blocks_per_row = fhs // blocking
        offsets = torch.arange(
            0,
            block_rows * blocks_per_row + 1,
            blocks_per_row,
            dtype=torch.int32,
            device=x.device,
        )

        # Indices for the sparse matrix. The indices for
        # the intermediate matrix are dynamic depending
        # on the mapping of tokens to experts.
        column_indices = ops.topology(
            padded_bins,
            blocking,
            block_rows,
            blocks_per_row,
        )
        data = torch.empty(
            column_indices.numel(),
            blocking,
            blocking,
            dtype=torch.float16,
            device=x.device,
        )
        shape = (padded_tokens, fhs * ne)
        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
        return stk.Matrix(shape, data, row_indices, column_indices, offsets)

    def build_input_matrix(self, sl, hs, ne):
        x = torch.randn((sl, hs)).cuda().half()

        # Assign tokens to experts uniformly.
        top_expert = torch.arange(0, sl).cuda().int() % ne

        bin_ids, indices = ops.sort(top_expert)
        tokens_per_expert = ops.histogram(top_expert, ne)
        padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
        out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
        return out, padded_bins

    def build_weight_matrix(self, ne, hs, fhs):
        return torch.randn((hs, ne * fhs)).cuda().half()

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
        topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
        w = transpose_view(w)

        def benchmark():
            return stk.ops.sdd(x, w, topo)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0::Fwd::SDD::NT',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
        topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)

        def benchmark():
            return stk.ops.dsd(topo, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0::GradX::DSD::NN',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
        topo = topo.t()

        def benchmark():
            return stk.ops.dsd(topo, x)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0::GradW::DSD::TN',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
        x = self.build_sparse_matrix(x, padded_bins, fhs, ne)

        def benchmark():
            return stk.ops.dsd(x, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::Fwd::DSD::NN',
            arguments,
            mean_t,
            std_t,
            x.nnz * hs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
        x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
        out = stk.ops.dsd(x, w)
        w = transpose_view(w)

        def benchmark():
            return stk.ops.sdd(out, w, x)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::GradX::SDD::NT',
            arguments,
            mean_t,
            std_t,
            x.nnz * hs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
        x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
        out = stk.ops.dsd(x, w)
        x = x.t()

        def benchmark():
            return stk.ops.dsd(x, out)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::GradW::DSD::TN',
            arguments,
            mean_t,
            std_t,
            x.nnz * hs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, hs)).cuda().half()
        w = torch.randn((ne, hs, fhs)).cuda().half()

        w = w.transpose(1, 2).contiguous()
        w = w.transpose(1, 2)

        def benchmark():
            return torch.bmm(x, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0::Fwd:DDD::NT',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, hs)).cuda().half()
        w = torch.randn((ne, hs, fhs)).cuda().half()
        out = torch.bmm(x, w)
        w = w.transpose(1, 2).contiguous()

        def benchmark():
            return torch.bmm(out, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0:GradX:DDD::NN',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, hs)).cuda().half()
        w = torch.randn((ne, hs, fhs)).cuda().half()
        out = torch.bmm(x, w)
        out = out.transpose(1, 2)

        def benchmark():
            return torch.bmm(out, x)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0:GradW:DDD::TN',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, fhs)).cuda().half()
        w = torch.randn((ne, fhs, hs)).cuda().half()

        def benchmark():
            return torch.bmm(x, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::Fwd::DDD::NN',
            arguments,
            mean_t,
            std_t,
            x.numel() * hs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, fhs)).cuda().half()
        w = torch.randn((ne, fhs, hs)).cuda().half()
        out = torch.bmm(x, w)
        w = torch.transpose(w, 1, 2)

        def benchmark():
            return torch.bmm(out, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::GradX::DDD::NT',
            arguments,
            mean_t,
            std_t,
            x.numel() * hs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, fhs)).cuda().half()
        w = torch.randn((ne, fhs, hs)).cuda().half()
        out = torch.bmm(x, w)
        x = torch.transpose(x, 1, 2)

        def benchmark():
            return torch.bmm(x, out)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::GradW::DDD::TN',
            arguments,
            mean_t,
            std_t,
            x.numel() * hs * 2,
        )


if __name__ == '__main__':
    unittest.main()
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py
ADDED
@@ -0,0 +1,55 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch
from stk.backend.autocast import custom_bwd, custom_fwd

from megablocks.backend import kernels


# Autograd wrapper for padded_gather kernel.
class PaddedGatherOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx: Any,
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        bins: torch.Tensor,
        padded_bins: torch.Tensor,
        top_k: int,
    ):
        ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
        ctx.top_k = top_k
        return kernels.padded_gather(
            x,
            indices,
            bin_ids,
            None,
            bins,
            padded_bins,
            top_k,
        )

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        grad = grad.contiguous()

        indices, bin_ids, bins, padded_bins = ctx.saved_tensors
        out = kernels.padded_scatter(
            grad,
            indices,
            bin_ids,
            None,
            bins,
            padded_bins,
            ctx.top_k,
        )
        return out, None, None, None, None, None


padded_gather = PaddedGatherOp.apply
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py
ADDED
@@ -0,0 +1,98 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch
from stk.backend.autocast import custom_bwd, custom_fwd

from megablocks.backend import kernels


# Autograd wrapper for padded_scatter kernel.
class PaddedScatterOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx: Any,
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        padded_bins: torch.Tensor,
        top_k: int,
    ):
        maybe_x = [x] if ctx.needs_input_grad[3] else []
        ctx.save_for_backward(
            indices,
            bin_ids,
            weights,
            bins,
            padded_bins,
            *maybe_x,
        )
        ctx.top_k = top_k
        ctx.x_shape = x.shape
        return kernels.padded_scatter(
            x,
            indices,
            bin_ids,
            weights,
            bins,
            padded_bins,
            top_k,
        )

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        grad = grad.contiguous()
        saved_tensors = ctx.saved_tensors

        indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
        dgrad = None
        if ctx.needs_input_grad[0]:
            dgrad = kernels.padded_gather(
                grad,
                indices,
                bin_ids,
                weights,
                bins,
                padded_bins,
                ctx.top_k,
            )

        wgrad = None
        if ctx.needs_input_grad[3]:  # need wgrad
            x = saved_tensors[-1]
            wgrad = kernels.padded_scatter_wgrad(
                x,
                grad,
                indices,
                bin_ids,
                bins,
                padded_bins,
                ctx.top_k,
            )
        return dgrad, None, None, wgrad, None, None, None, None


def padded_scatter(
    x: torch.Tensor,
    indices: torch.Tensor,
    bin_ids: torch.Tensor,
    weights: torch.Tensor,
    bins: torch.Tensor,
    padded_bins: torch.Tensor,
    top_k: int,
):
    return PaddedScatterOp.apply(
        x,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        top_k,
    )
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py
ADDED
@@ -0,0 +1,66 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import unittest

import torch
from absl.testing import parameterized

from megablocks import benchmark_util, ops

_PADDED_SCATTER_BENCHMARK = (
    # dMoE-Medium, 8-way EMP.
    (1024 * 16, 1024, 8, 4),
    # dMoE-Medium, post-all-to-all.
    (1024 * 16 * 4, 1024, 8, 1),
)


class PaddedScatterTest(parameterized.TestCase):

    @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
    def testPaddedScatter(self, sl, hs, ne, top_k):
        # Create the data and indices.
        x = torch.randn((sl, hs)).cuda().half()

        # Randomly assign tokens to experts.
        top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
        bin_ids, indices = ops.sort(top_expert)
        tokens_per_expert = ops.histogram(top_expert, ne)
        padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)

        # Sample weights for the scatter reduce.
        weights = torch.rand((sl * top_k,)).cuda().half()

        # Gather the data to prepare for backwards.
        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)

        def benchmark():
            return ops.padded_scatter(
                x,
                indices,
                bin_ids,
                weights,
                bins,
                padded_bins,
                top_k,
            )

        time, std = benchmark_util.benchmark_function(benchmark)
        benchmark_util.log_benchmark(
            'Padded Scatter',
            {
                'sequence_length': sl,
                'hidden_size': hs,
                'num_experts': ne,
                'top_k': top_k,
            },
            time,
            std,
        )


if __name__ == '__main__':
    unittest.main()
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import unittest
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from absl.testing import parameterized
|
8 |
+
|
9 |
+
from megablocks import benchmark_util, ops
|
10 |
+
|
11 |
+
_PERMUTE_TESTS = (
|
12 |
+
(16384, 768, 2),
|
13 |
+
(16384, 768, 4),
|
14 |
+
(16384, 768, 8),
|
15 |
+
(16384, 768, 16),
|
16 |
+
+    (16384, 768, 32),
+    (16384, 768, 64),
+    (16384, 768, 128),
+    (16384 * 8, 768, 2),
+    (16384 * 8, 768, 4),
+    (16384 * 8, 768, 8),
+    (16384 * 8, 768, 16),
+    (16384 * 8, 768, 32),
+    (16384 * 8, 768, 64),
+    (16384 * 8, 768, 128),
+)
+
+
+class PermuteBenchmark(parameterized.TestCase):
+
+    @parameterized.parameters(*_PERMUTE_TESTS)
+    def testBinnedGather(self, sl, hs, ne):
+        # NOTE: Capacity factor == 1.
+        ec = sl // ne
+
+        # Create the data and indices.
+        x = torch.randn((sl, hs)).cuda().half()
+        top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+        bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(indices, ne)
+        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+        def benchmark():
+            return ops.binned_gather(x, indices, bins, ec)
+
+        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+        arguments = {
+            'sequence_length': sl,
+            'hidden_size': hs,
+            'num_experts': ne,
+        }
+        benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
+
+    @parameterized.parameters(*_PERMUTE_TESTS)
+    def testBinnedScatter(self, sl, hs, ne):
+        # NOTE: Capacity factor == 1.
+        ec = sl // ne
+
+        # Create the data and indices.
+        x = torch.randn((sl, hs)).cuda().half()
+        top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+        bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(indices, ne)
+        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.binned_gather(x, indices, bins, ec)
+
+        def benchmark():
+            return ops.binned_scatter(x, indices, bins)
+
+        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+        arguments = {
+            'sequence_length': sl,
+            'hidden_size': hs,
+            'num_experts': ne,
+        }
+        benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
+
+    @parameterized.parameters(*_PERMUTE_TESTS)
+    def testPaddedGather(self, sl, hs, ne):
+        # Create the data and indices.
+        x = torch.randn((sl, hs)).cuda().half()
+
+        # Randomly assign tokens to experts.
+        top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+        bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+        padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+        def benchmark():
+            return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+
+        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+        arguments = {
+            'sequence_length': sl,
+            'hidden_size': hs,
+            'num_experts': ne,
+        }
+        benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
+
+    @parameterized.parameters(*_PERMUTE_TESTS)
+    def testPaddedScatter(self, sl, hs, ne):
+        # Create the data and indices.
+        x = torch.randn((sl, hs)).cuda().half()
+
+        # Randomly assign tokens to experts.
+        top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+        bin_ids, indices = ops.sort(top_expert)
+        tokens_per_expert = ops.histogram(top_expert, ne)
+        padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
+
+        def benchmark():
+            return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
+
+        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+        arguments = {
+            'sequence_length': sl,
+            'hidden_size': hs,
+            'num_experts': ne,
+        }
+        benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
+
+    @parameterized.parameters(*_PERMUTE_TESTS)
+    def testCopy(self, sl, hs, ne):
+        # NOTE: Capacity factor == 1.
+        # ec = sl // ne
+
+        # Create the data and indices.
+        x = torch.randn((sl, hs)).cuda().half()
+        y = x.clone()
+
+        def benchmark():
+            return y.copy_(x)
+
+        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
+        arguments = {
+            'sequence_length': sl,
+            'hidden_size': hs,
+            'num_experts': ne,
+        }
+        benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+    unittest.main()
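For orientation, every case in PermuteBenchmark exercises the same token-permutation pipeline: sort the token-to-expert assignments, histogram them, turn the counts into bins, then gather tokens into expert order and scatter them back. The sketch below mirrors those calls outside the test harness; it assumes a CUDA device and the compiled megablocks extension, and the sizes are illustrative only.

import torch
from megablocks import ops

sl, hs, ne = 16384, 768, 4                    # sequence length, hidden size, experts
ec = sl // ne                                 # expert capacity at capacity factor == 1

x = torch.randn((sl, hs)).cuda().half()
top_expert = torch.randint(0, ne, (sl,)).cuda().int()

# Group token->expert assignments: sorted expert ids, permutation, counts, bin edges.
bin_ids, indices = ops.sort(top_expert)
tokens_per_expert = ops.histogram(top_expert, ne)
bins = ops.inclusive_cumsum(tokens_per_expert, 0)

# Permute tokens into per-expert bins, then un-permute (same calls as the tests above).
grouped = ops.binned_gather(x, indices, bins, ec)
restored = ops.binned_scatter(grouped, indices, bins)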
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py
ADDED
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+    if all((t == 1 for t in tiling)):
+        return x
+    return x.repeat(*tiling)
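repeat is a thin wrapper around torch.Tensor.repeat that skips the copy entirely when every tiling factor is 1. A quick CPU-only check of both branches (the helper itself has no CUDA dependency):

import torch
from megablocks.ops.repeat import repeat

w = torch.randn(4, 8)
assert repeat(w, torch.Size((1, 1))) is w              # all-ones tiling: same tensor, no copy
assert repeat(w, torch.Size((2, 1))).shape == (8, 8)   # otherwise defers to Tensor.repeat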
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py
ADDED
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+    from megablocks._ops import ops  # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+        ctx.save_for_backward(bins)
+        out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+        ops.replicate_forward(x, bins, out)
+        return out
+
+    @staticmethod
+    def backward(ctx: Any, grad: torch.Tensor):
+        bins, = ctx.saved_tensors
+        out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+        ops.replicate_backward(grad, bins, out)
+        return out, None, None
+
+
+replicate = ReplicateOp.apply
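A hypothetical usage sketch for the replicate wrapper, under the assumption (consistent with the output shape in forward above) that x carries one value per bin along its last dimension, bins is the inclusive cumsum of the bin sizes, and num_outputs is the total number of positions covered. It requires a GPU and the compiled extension; treat it as illustrative only.

import torch
from megablocks import ops

bin_sizes = torch.tensor([3, 1, 4], dtype=torch.int32, device='cuda')
bins = ops.inclusive_cumsum(bin_sizes, 0)              # bin boundaries: [3, 4, 8]
values = torch.randn((1, 3), device='cuda', dtype=torch.float16)

# Expand the per-bin values across the positions in each bin: output shape (1, 8).
out = ops.replicate(values, bins, int(bins[-1].item()))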
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py
ADDED
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+    assert isinstance(value, int)
+    assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue
+    # do this in a custom kernel. We only expect
+    # to use this on arrays of less than 1k elements.
+    return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
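round_up is what the padded gather/scatter benchmarks above use to pad each expert's token count up to a multiple of the block size before building padded_bins. The function itself is plain PyTorch, so a small example is enough to see the behavior:

import torch
from megablocks import ops

tokens_per_expert = torch.tensor([5, 130, 0, 128], dtype=torch.int32)
print(ops.round_up(tokens_per_expert, 128))
# tensor([128, 256,   0, 128], dtype=torch.int32)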
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py
ADDED
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from stk.backend.autocast import custom_bwd, custom_fwd
+
+from megablocks.backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(
+        ctx: Any,
+        x: torch.Tensor,
+        indices: torch.Tensor,
+        bin_ids: torch.Tensor,
+        weights: torch.Tensor,
+        bins: torch.Tensor,
+        top_k: int,
+    ) -> torch.Tensor:
+        maybe_x = [x] if ctx.needs_input_grad[3] else []
+        ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+        ctx.top_k = top_k
+        ctx.x_shape = x.shape
+        return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx: Any, grad: torch.Tensor):
+        grad = grad.contiguous()
+        saved_tensors = ctx.saved_tensors
+
+        indices, bin_ids, weights, bins = saved_tensors[:4]
+        dgrad = None
+        if ctx.needs_input_grad[0]:
+            dgrad = kernels.gather(
+                grad,
+                indices,
+                bin_ids,
+                weights,
+                bins,
+                ctx.top_k,
+            )
+
+        wgrad = None
+        if ctx.needs_input_grad[3]:  # need wgrad
+            x = saved_tensors[-1]
+            wgrad = kernels.scatter_wgrad(
+                x,
+                grad,
+                indices,
+                bin_ids,
+                bins,
+                ctx.top_k,
+            )
+        return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+    x: torch.Tensor,
+    indices: torch.Tensor,
+    bin_ids: torch.Tensor,
+    weights: torch.Tensor,
+    bins: torch.Tensor,
+    top_k: int,
+) -> Optional[torch.Tensor]:
+    return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
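A hedged sketch of how this scatter wrapper pairs with gather in an MoE forward pass: gather permutes (token, expert) pairs into expert order, and scatter un-permutes the expert outputs back to token order while applying the router weights and summing over the top_k choices. The shapes are illustrative and this is not the layer's actual code path; it assumes a GPU and the built extension.

import torch
from megablocks import ops

num_tokens, hidden, num_experts, top_k = 1024, 768, 4, 2

x = torch.randn((num_tokens, hidden), device='cuda', dtype=torch.float16)
expert_weights = torch.rand((num_tokens * top_k,), device='cuda', dtype=torch.float16)
top_experts = torch.randint(0, num_experts, (num_tokens * top_k,), device='cuda').int()

bin_ids, indices = ops.sort(top_experts)
tokens_per_expert = ops.histogram(top_experts, num_experts)
bins = ops.inclusive_cumsum(tokens_per_expert, 0)

permuted = ops.gather(x, indices, bin_ids, bins, top_k)   # tokens in expert order
# ... per-expert MLPs would run on `permuted` here ...
combined = ops.scatter(permuted, indices, bin_ids, expert_weights, bins, top_k)  # (num_tokens, hidden)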
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py
ADDED
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+    from megablocks._ops import ops  # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+    torch.int16: 16,
+    torch.int32: 32,
+    torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+        if end_bit is None:
+            end_bit = _BITS_FOR_DTYPE[x.dtype]
+        x_out = torch.empty_like(x)
+        iota_out = torch.empty_like(x)
+        ops.sort(x, end_bit, x_out, iota_out)
+        return (x_out, iota_out)
+
+
+sort = SortOp.apply
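sort is the op the benchmarks above use to group token-to-expert assignments: it returns the sorted keys plus the permutation that produced them, and gradients are not supported. A brief usage sketch, assuming a GPU and the built extension:

import torch
from megablocks import ops

top_expert = torch.randint(0, 8, (4096,)).cuda().int()
bin_ids, indices = ops.sort(top_expert)      # end_bit defaults to the dtype width (32 here)
bin_ids, indices = ops.sort(top_expert, 3)   # 8 expert ids fit in 3 bits, bounding the radix passes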
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py
ADDED
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from megablocks import ops
+
+_SORT_TESTS = (
+    (16384, torch.int32, None),
+    (16384, torch.int32, 2),
+    (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+    types = {
+        torch.int16: np.int16,
+        torch.int32: np.int32,
+        torch.int64: np.int64,
+    }
+    return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+    # Run once to get rid of startup overhead.
+    fn()
+    times = []
+    for _ in range(iterations):
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
+        fn()
+        end.record()
+        torch.cuda.synchronize()
+        times.append(start.elapsed_time(end))
+    times = np.array(times)
+    return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+    print('=' * 60)
+    print('Benchmark Parameters:')
+    for (key, value) in arguments.items():
+        print(f'{key} = {value}')
+    print('Results:')
+    print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+    print('=' * 60)
+
+
+class SortBenchmark(parameterized.TestCase):
+
+    @parameterized.parameters(*_SORT_TESTS)
+    def testSort(self, n, dtype, max_val):
+        if max_val is None:
+            max_val = np.iinfo(numpy_dtype(dtype)).max
+        end_bit = int(np.ceil(np.log2(max_val)))
+        x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+        mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+        arguments = {
+            'n': n,
+            'dtype': dtype,
+            'max_val': max_val,
+        }
+        log_benchmark(arguments, mean_t, std_t)
+
+    @parameterized.parameters(*_BASELINE_SORT_TESTS)
+    def testTorchSort(self, n):
+        x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+
+        mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+        arguments = {
+            'n': n,
+        }
+        log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+    unittest.main()
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py
ADDED
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+    if x.shape[dim] == 1:
+        return x.squeeze(dim=dim)
+    return x.sum(dim=dim)
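sum only exists to skip the reduction when the requested dimension is already singleton; a quick CPU check of both branches:

import torch
from megablocks.ops.sum import sum as mb_sum

a = torch.randn(1, 5)
b = torch.randn(3, 5)
print(mb_sum(a, dim=0).shape)  # torch.Size([5]) -- singleton dim: just a squeeze
print(mb_sum(b, dim=0).shape)  # torch.Size([5]) -- otherwise a real sum over dim 0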
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py
ADDED
@@ -0,0 +1,45 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+    from megablocks._ops import ops  # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for topology kernel.
+# NOTE: Does not support gradients.
+class TopologyOp(torch.autograd.Function):
+
+    @staticmethod
+    def forward(
+        ctx: Any,
+        padded_bins: torch.Tensor,
+        block_size: int,
+        output_block_rows: int,
+        output_block_columns: int,
+    ):
+        out = torch.empty(
+            output_block_rows * output_block_columns,
+            dtype=torch.int16,
+            device=padded_bins.device,
+        )
+        ops.indices(
+            padded_bins,
+            block_size,
+            output_block_rows,
+            output_block_columns,
+            out,
+        )
+        return out
+
+
+topology = TopologyOp.apply
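A hypothetical sketch of how topology() might be driven: given padded per-expert bins, it emits an int16 index array describing the block-sparse layout used by the dropless MoE. The block_rows/blocks_per_row arithmetic below (padded tokens and FFN width divided by the block size) is an assumption for illustration, not taken from this file; a GPU and the built extension are required.

import torch
from megablocks import ops

block_size = 128
ffn_hidden = 1024                                   # illustrative per-expert FFN width
tokens_per_expert = torch.tensor([256, 128, 384, 256], dtype=torch.int32, device='cuda')

padded_bins = ops.inclusive_cumsum(ops.round_up(tokens_per_expert, block_size), 0)
padded_tokens = int(padded_bins[-1].item())

output_block_rows = padded_tokens // block_size
output_block_columns = ffn_hidden // block_size
column_indices = ops.topology(padded_bins, block_size, output_block_rows, output_block_columns)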
build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py
ADDED
@@ -0,0 +1,195 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from megablocks.layers.arguments import Arguments
+from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE
+from megablocks.layers.glu import SparseGLU
+from megablocks.layers.mlp import MLP, SparseMLP
+from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+    """
+    Compute exclusive cumulative sum along the specified dimension.
+
+    Args:
+        x: Input tensor
+        dim: Dimension along which to compute cumsum
+        out: Output tensor (modified in-place)
+
+    Returns:
+        The output tensor
+    """
+    result = ops.exclusive_cumsum(x, dim)
+    out.copy_(result)
+    return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+    """
+    Compute inclusive cumulative sum along the specified dimension.
+
+    Args:
+        x: Input tensor
+        dim: Dimension along which to compute cumsum
+        out: Output tensor (modified in-place)
+
+    Returns:
+        The output tensor
+    """
+    result = ops.inclusive_cumsum(x, dim)
+    out.copy_(result)
+    return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+    """
+    Compute histogram of input tensor values.
+
+    Args:
+        x: Input tensor
+        num_bins: Number of histogram bins
+
+    Returns:
+        Histogram tensor with counts for each bin
+    """
+    return ops.histogram(x, num_bins)
+
+
+def indices(
+    padded_bins: torch.Tensor,
+    block_size: int,
+    output_block_rows: int,
+    output_block_columns: int,
+) -> torch.Tensor:
+    """
+    Construct indices from padded bins for sparse operations.
+
+    Args:
+        padded_bins: Tensor containing bin boundaries
+        block_size: Size of each block
+        output_block_rows: Number of rows in output blocks
+        output_block_columns: Number of columns in output blocks
+
+    Returns:
+        Tensor containing constructed indices
+    """
+    return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+    x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+    """
+    Forward pass of replicate operation - replicate values according to bin sizes.
+
+    Args:
+        x: Input tensor with values to replicate
+        bins: Tensor containing bin sizes
+        out: Output tensor (modified in-place)
+
+    Returns:
+        The output tensor
+    """
+    return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+    grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+    """
+    Backward pass of replicate operation - reduce gradients back to bins.
+
+    Args:
+        grad: Gradient tensor to reduce
+        bins: Tensor containing bin sizes
+        out: Output tensor (modified in-place)
+
+    Returns:
+        The output tensor
+    """
+    return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+    x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+    """
+    Radix sort with index tracking.
+
+    Args:
+        x: Input tensor to sort
+        end_bit: Number of bits to consider in sorting
+        x_out: Output tensor for sorted values
+        iota_out: Output tensor for sorted indices
+
+    Returns:
+        The sorted values tensor
+    """
+    return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+    """
+    Compute cumulative sum with automatic output allocation.
+
+    Args:
+        x: Input tensor
+        dim: Dimension along which to compute cumsum (default: last dimension)
+        exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+    Returns:
+        New tensor containing the cumulative sum
+    """
+    out = torch.empty_like(x)
+    if exclusive:
+        return exclusive_cumsum(x, dim, out)
+    else:
+        return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Sort tensor and return both sorted values and indices.
+
+    Args:
+        x: Input tensor to sort
+        end_bit: Number of bits to consider in sorting
+
+    Returns:
+        Tuple of (sorted_values, sorted_indices)
+    """
+    x_out = torch.empty_like(x)
+    iota_out = torch.empty_like(x)
+    sort(x, end_bit, x_out, iota_out)
+    return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+    # Direct kernel exports
+    "exclusive_cumsum",
+    "inclusive_cumsum",
+    "histogram",
+    "indices",
+    "replicate_forward",
+    "replicate_backward",
+    "sort",
+    "cumsum",
+    "argsort",
+    # Original exports
+    "Arguments",
+    "ParallelDroplessMLP",
+    "dMoE",
+    "SparseGLU",
+    "MLP",
+    "SparseMLP",
+    "MoE",
+    "ParallelMLP",
+    "get_load_balancing_loss",
+]
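Because the convenience wrappers above allocate their own outputs, a routing pre-pass can be written directly against the package root. A short sketch, assuming a GPU and the built extension installed as megablocks:

import torch
import megablocks

top_expert = torch.randint(0, 8, (4096,)).cuda().int()

sorted_ids, perm = megablocks.argsort(top_expert, end_bit=3)    # wraps ops.sort
counts = megablocks.histogram(top_expert, num_bins=8)           # tokens per expert
bins = megablocks.cumsum(counts, dim=0)                         # inclusive by default
starts = megablocks.cumsum(counts, dim=0, exclusive=True)       # exclusive start offsets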
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12cf047c4bcb5f368f490ada3dafcecd1242a443ff5ded94c09d62548922098c
+size 11795992
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py
ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _megablocks_359242d
+ops = torch.ops._megablocks_359242d
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_megablocks_359242d::{op_name}"
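add_op_namespace_prefix simply qualifies an op name with this build's torch.ops namespace; for example:

from megablocks._ops import add_op_namespace_prefix

add_op_namespace_prefix("exclusive_cumsum")
# -> '_megablocks_359242d::exclusive_cumsum'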
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_version.py
ADDED
@@ -0,0 +1,6 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+"""The MegaBlocks Version."""
+
+__version__ = '0.11.0.dev0'