drbh committed
Commit · b2bfc37
1 Parent(s): 484fde0

fix: bump build and imports
This view is limited to 50 files because it contains too many changes.
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_a585153_dirty.abi3.so → _megablocks_6756875_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py +2 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py +2 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py +2 -2
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py +8 -5
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py +16 -5
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py +2 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py +9 -5
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py +50 -18
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py +2 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py +4 -2
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py +4 -2
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py +14 -14
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +5 -2
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py +1 -1
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_a585153_dirty.abi3.so → _megablocks_6756875_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py +2 -1
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/arguments.py +2 -1
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/common.py +1 -1
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py +2 -2
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmoe.py +8 -5
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/glu.py +16 -5
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/memory_test.py +2 -1
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mlp.py +9 -5
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/moe.py +50 -18
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mpu.py +2 -1
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/router.py +4 -2
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py +4 -2
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/__init__.py +14 -14
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +5 -2
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py +1 -1
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py +1 -1
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_a585153_dirty.abi3.so → _megablocks_6756875_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:ad46e9f244afa886c8a104d75e37f93afd2a0ecf83bfc7a414680fa16d8b78f9
 size 10517608
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_a585153_dirty
-ops = torch.ops._megablocks_a585153_dirty
+from . import _megablocks_6756875_dirty
+ops = torch.ops._megablocks_6756875_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_a585153_dirty::{op_name}"
+    return f"_megablocks_6756875_dirty::{op_name}"
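For context, a minimal sketch of how the bumped namespace above is used; the op name `sort` is illustrative, and the assertion just demonstrates the string the helper builds:

```python
import torch

# The compiled extension registers its custom ops under a build-specific
# namespace, so they are reachable as torch.ops.<namespace>.<op_name>.
ops = torch.ops._megablocks_6756875_dirty

def add_op_namespace_prefix(op_name: str) -> str:
    # Builds the fully qualified "<namespace>::<op_name>" form that
    # torch.library-style registration APIs expect.
    return f"_megablocks_6756875_dirty::{op_name}"

assert add_op_namespace_prefix("sort") == "_megablocks_6756875_dirty::sort"
```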
build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py
CHANGED
@@ -10,7 +10,8 @@ import torch
 # We import the backend operations from the megablocks package as
 # grouped_gemm is vendored in megablocks in this repository.
 # from ... import _ops as backend
-from megablocks._ops import ops as backend  # type: ignore
+# from megablocks._ops import ops as backend  # type: ignore
+from .._ops import ops as backend  # type: ignore
 
 def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
     assert not (trans_a and trans_b)
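The switch from `megablocks._ops` to `.._ops` keeps the vendored copy self-contained regardless of which `megablocks` is on `sys.path`. A sketch of the assumed package layout, inferred from the paths in this commit:

```python
# megablocks/                       <- package root (one build variant)
#     _ops.py                       <- binds torch.ops._megablocks_6756875_dirty
#     grouped_gemm/
#         backend.py                <- `from .._ops import ops as backend`
#     layers/
#         arguments.py, moe.py, ... <- `from .arguments import Arguments`, etc.
#     ops/
#         __init__.py               <- re-exports sort, histogram, cumsum, ...
```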
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py
CHANGED
@@ -9,7 +9,8 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 
-import megablocks.grouped_gemm_util as grouped_gemm
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
 
 # Type annotation for in-place Tensor initialization function.
 InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py
CHANGED
@@ -3,7 +3,7 @@
 
 import torch
 
-from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
 
 
 def dtype(args: Arguments):
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py
CHANGED
@@ -3,8 +3,8 @@
 
 from typing import Union
 
-from megablocks.layers import glu, mlp
-from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
 
 MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py
CHANGED
@@ -6,11 +6,14 @@ import stk.ops
 import torch
 from stk import Matrix
 
-import megablocks.ops as ops
-# from megablocks.ops import ops
-from megablocks.layers import common, dmlp_registry, moe, mpu
-from megablocks.layers.arguments import Arguments
-
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
 
 def promote_scalar(x):
     return x.view(1) if not len(x.size()) else x
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py
CHANGED
@@ -4,11 +4,22 @@
 import stk.ops
 import torch
 
-from megablocks import grouped_gemm_util as gg
-from megablocks.layers import common, mpu
-from megablocks.layers.activation_fn import act_fn
-from megablocks.layers.arguments import Arguments
-from megablocks.layers.mlp import (
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+#     SharedMLP,
+#     SparseMLP,
+#     create_dmoe_expert_weights,
+#     resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
     SharedMLP,
     SparseMLP,
     create_dmoe_expert_weights,
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py
CHANGED
@@ -6,7 +6,8 @@ import gc
 import torch
 import torch.distributed as dist
 
-from megablocks.layers import arguments, dmoe
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
 
 _TESTS = ((8, 2048, 4096, 4096, 32, 4),)
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py
CHANGED
@@ -9,11 +9,15 @@ import stk.ops
 import torch
 from packaging import version
 
-from megablocks import grouped_gemm_util as gg
-from megablocks.layers import common, gelu, mpu
-from megablocks.layers.activation_fn import act_fn
-from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
-
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
 
 class ScaleGradient(torch.autograd.Function):
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py
CHANGED
@@ -6,10 +6,27 @@ import numpy as np
 import torch
 import torch.distributed as dist
 
-import megablocks.ops as ops
-from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
-from megablocks.layers.all_to_all import all_to_all
-from megablocks.layers.arguments import Arguments
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+    sort,
+    histogram,
+    inclusive_cumsum,
+    exclusive_cumsum,
+    binned_gather,
+    binned_scatter,
+    gather,
+    scatter,
+    repeat,
+    replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
 
 _LOAD_BALANCING_LOSS = []
 
@@ -158,7 +175,8 @@ class ParallelMLP(torch.nn.Module):
         # prior? Could we place the `torch.max` operation to return
         # 32-bit expert indices?
         top_expert = top_expert.int()
-        output = ops.sort(top_expert, self.sort_end_bit)
+        # output = ops.sort(top_expert, self.sort_end_bit)
+        output = sort(top_expert, self.sort_end_bit)
         assert output is not None
         bin_ids, indices = output
 
@@ -168,10 +186,12 @@ class ParallelMLP(torch.nn.Module):
         # TODO(tgale): Does the sorted data produce a more favorable
         # data distribution for histogram? Or is the op parallelism
        # worth more?
-        tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+        # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+        tokens_per_expert = histogram(top_expert, self.num_experts)
 
         # Calculate the bin bounds for the sorted tokens.
-        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        bins = inclusive_cumsum(tokens_per_expert, 0)
         assert bins is not None
         bins = bins.view(1) if not len(bins.size()) else bins
 
@@ -195,7 +215,8 @@ class ParallelMLP(torch.nn.Module):
     ):
         # Route the tokens for MoE computation.
         x = x.view(-1, x.shape[-1])
-        output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+        # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+        output = binned_gather(x, indices, bins, expert_capacity, top_k)
         assert output is not None
         x = output
 
@@ -204,7 +225,9 @@ class ParallelMLP(torch.nn.Module):
         x = self.mlp(x)
 
         # Un-route the data for the MoE output.
-        return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+        # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+        return binned_scatter(x, indices, expert_weights, bins, top_k)
+
 
     def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
         # x: [sl, bs, hs]
@@ -264,7 +287,8 @@ class ParallelMLP(torch.nn.Module):
         # If we're sharding the experts along the hidden dimension
         # multiple devices own parts of the same sets of experts.
         # Replicate the token counts so every device gets the counts.
-        repeated_tokens_per_expert = ops.repeat(
+        # repeated_tokens_per_expert = ops.repeat(
+        repeated_tokens_per_expert = repeat(
             tokens_per_expert,
             (mpu.hidden_sharding_degree(self.args),),
         )
@@ -285,7 +309,8 @@ class ParallelMLP(torch.nn.Module):
         # This view updates the shape of the tensor from [sl, bs, hs] to
         # [sl * bs, hs] prior to the permutation.
         x = x.view(-1, x.shape[-1])
-        output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+        # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+        output = gather(x, indices, bin_ids, bins, self.top_k)
         assert output is not None
         x = output
 
@@ -317,7 +342,8 @@ class ParallelMLP(torch.nn.Module):
         # get all of the tokens assigned to them.
         #
         # TODO(tgale): Fuse this into the prior, local permutation.
-        x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+        # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+        x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
 
         # Start the cross-device permutation asynchronously so we can
         # overlap communication with computation.
@@ -336,7 +362,8 @@ class ParallelMLP(torch.nn.Module):
         # for expert computation we'll do one more local permutation. The
         # rest of this torch.no_grad() scope sets up the indices and bins
         # for this permutation.
-        replicate_bins = ops.inclusive_cumsum(
+        # replicate_bins = ops.inclusive_cumsum(
+        replicate_bins = inclusive_cumsum(
             parallel_tokens_per_expert.flatten(),
             0,
         )
@@ -351,14 +378,16 @@ class ParallelMLP(torch.nn.Module):
             ),
             mpu.experts_per_rank(self.args),
         )
-        parallel_top_expert = ops.replicate(
+        # parallel_top_expert = ops.replicate(
+        parallel_top_expert = replicate(
             parallel_top_expert.unsqueeze(dim=0),
             replicate_bins,
             tokens_received,
         ).flatten()
 
         # TODO(tgale): The sort_end_bit here can be reduced.
-        parallel_bin_ids, parallel_indices = ops.sort(
+        # parallel_bin_ids, parallel_indices = ops.sort(
+        parallel_bin_ids, parallel_indices = sort(
             parallel_top_expert,
             self.sort_end_bit,
         )
@@ -368,7 +397,8 @@ class ParallelMLP(torch.nn.Module):
             dim=0,
             dtype=torch.int,
         )
-        parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+        # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+        parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
         parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
 
         # If expert_capacity is set to zero, set the number of tokens
@@ -416,10 +446,12 @@ class ParallelMLP(torch.nn.Module):
             -1,
             self.args.hidden_size,
         )
-        x = ops.sum(x.view(shape), dim=0)
+        # x = ops.sum(x.view(shape), dim=0)
+        x = x.view(shape).sum(dim=0)
 
         # Un-permute locally to setup for the next series of operations.
-        x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+        # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+        x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
         return x, tokens_per_expert.flatten()
 
     def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
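One functional change rides along with the import bump: `ops.sum(x.view(shape), dim=0)` becomes plain `x.view(shape).sum(dim=0)`. A quick sketch of why the stock reduction is a drop-in replacement here; the sizes and the leading shape entry are stand-ins, not taken from a real run:

```python
import torch

hsd = 2                      # stand-in for mpu.hidden_sharding_degree(args)
tokens, hidden_size = 8, 16  # stand-in for sl * bs and args.hidden_size
x = torch.randn(hsd * tokens, hidden_size)

# Same view used in the hunk above, then sum the replicated hidden
# shards over dim 0 with the built-in Tensor.sum.
shape = (hsd, -1, hidden_size)
x = x.view(shape).sum(dim=0)
assert x.shape == (tokens, hidden_size)
```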
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py
CHANGED
@@ -6,7 +6,8 @@ from typing import Optional
 import torch
 import torch.distributed as dist
 
-from megablocks.layers.arguments import Arguments
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
 
 
 class MoeParam(torch.Tensor):
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py
CHANGED
@@ -4,8 +4,10 @@ from typing import Any
 
 import torch
 
-from megablocks.layers import common
-from megablocks.layers.arguments import Arguments
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
 
 _ROUTER_LOGITS = []
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py
CHANGED
@@ -3,8 +3,10 @@
 
 from typing import Union
 
-from megablocks.layers import glu, mlp
-from megablocks.layers.arguments import Arguments
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
 
 _REGISTRY = {
     'mlp': mlp.SharedMLP,
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py
CHANGED
@@ -1,20 +1,20 @@
 # Copyright 2024 Databricks
 # SPDX-License-Identifier: Apache-2.0
 
-from megablocks.ops.binned_gather import binned_gather
-from megablocks.ops.binned_scatter import binned_scatter
-from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum
-from megablocks.ops.gather import gather
-from megablocks.ops.histogram import histogram
-from megablocks.ops.padded_gather import padded_gather
-from megablocks.ops.padded_scatter import padded_scatter
-from megablocks.ops.repeat import repeat
-from megablocks.ops.replicate import replicate
-from megablocks.ops.round_up import round_up
-from megablocks.ops.scatter import scatter
-from megablocks.ops.sort import sort
-from megablocks.ops.sum import sum
-from megablocks.ops.topology import topology
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
 
 __all__ = [
     'binned_gather',
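These re-exports are what let the reworked `layers/moe.py` above spell `from ..ops import sort, histogram, ...` instead of going through an `ops` module object. A sketch of the equivalence, assuming an installed megablocks build:

```python
from megablocks import ops
from megablocks.ops import sort, inclusive_cumsum

# Both spellings resolve to the same function objects; the package
# __init__ simply lifts each op out of its defining submodule.
assert ops.sort is sort
assert ops.inclusive_cumsum is inclusive_cumsum
```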
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py
CHANGED
@@ -4,8 +4,11 @@
 import torch
 import torch.distributed as dist
 
-from megablocks import benchmark_util
-from megablocks.layers.all_to_all import all_to_all
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from ..layers.all_to_all import all_to_all
 
 _ALL_TO_ALL_BENCHMARK = (
     (8, 1024),
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Any
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
 
-from megablocks.backend import kernels
+from ..backend import kernels
 
 
 # Autograd wrapper for binned_gather kernel.
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Any
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
 
-from megablocks.backend import kernels
+from ..backend import kernels
 
 
 # Autograd wrapper for binned_scatter kernel.
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py
CHANGED
@@ -11,7 +11,7 @@ import torch
 # instructions for building the c++ operations.
 try:
     # import megablocks_ops as ops  # type: ignore
-    from megablocks._ops import ops  # type: ignore
+    from .._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
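For reference, the bin-bound semantics these cumsum ops provide to `moe.py`, sketched with stock `torch.cumsum`; the compiled op computes the same scan for the GPU path:

```python
import torch

tokens_per_expert = torch.tensor([3, 0, 2, 5])

# Inclusive scan: bins[i] is where expert i's tokens end in sorted order.
inclusive = torch.cumsum(tokens_per_expert, dim=0)  # tensor([ 3,  3,  5, 10])
# Exclusive scan: where each expert's tokens start.
exclusive = inclusive - tokens_per_expert           # tensor([0, 3, 3, 5])
```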
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Any
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
 
-from megablocks.backend import kernels
+from ..backend import kernels
 
 
 # Autograd wrapper for gather kernel.
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py
CHANGED
@@ -10,7 +10,7 @@ import torch
 # Wrap this in a try-block with better error message and
 # instructions for building the c++ operations.
 try:
-    from megablocks._ops import ops  # type: ignore
+    from .._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py
CHANGED
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 from absl.testing import parameterized
 
-from megablocks import ops
+from .. import ops
 
 _HISTOGRAM_TESTS = (
     (16384, torch.int32, 2),
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py
CHANGED
@@ -7,7 +7,7 @@ import stk
 import torch
 from absl.testing import parameterized
 
-from megablocks import benchmark_util, ops
+from .. import benchmark_util, ops
 
 
 # Calling tensor.t() calls tensor.transpose(0, 1) which calls
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Any
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
 
-from megablocks.backend import kernels
+from ..backend import kernels
 
 
 # Autograd wrapper for padded_gather kernel.
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Any
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
 
-from megablocks.backend import kernels
+from ..backend import kernels
 
 
 # Autograd wrapper for padded_scatter kernel.
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py
CHANGED
@@ -6,7 +6,7 @@ import unittest
 import torch
 from absl.testing import parameterized
 
-from megablocks import benchmark_util, ops
+from .. import benchmark_util, ops
 
 _PADDED_SCATTER_BENCHMARK = (
     # dMoE-Medium, 8-way EMP.
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py
CHANGED
@@ -6,7 +6,7 @@ import unittest
 import torch
 from absl.testing import parameterized
 
-from megablocks import benchmark_util, ops
+from .. import benchmark_util, ops
 
 _PERMUTE_TESTS = (
     (16384, 768, 2),
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py
CHANGED
@@ -10,7 +10,7 @@ import torch
 # Wrap this in a try-block with better error message and
 # instructions for building the c++ operations.
 try:
-    from megablocks._ops import ops  # type: ignore
+    from .._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py
CHANGED
@@ -6,7 +6,7 @@ from typing import Any, Optional
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
 
-from megablocks.backend import kernels
+from ..backend import kernels
 
 
 # Autograd wrapper for scatter kernel.
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py
CHANGED
@@ -10,7 +10,7 @@ import torch
 # Wrap this in a try-block with better error message and
 # instructions for building the c++ operations.
 try:
-    from megablocks._ops import ops  # type: ignore
+    from .._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py
CHANGED
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 from absl.testing import parameterized
 
-from megablocks import ops
+from .. import ops
 
 _SORT_TESTS = (
     (16384, torch.int32, None),
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py
CHANGED
@@ -10,7 +10,7 @@ import torch
 # Wrap this in a try-block with better error message and
 # instructions for building the c++ operations.
 try:
-    from megablocks._ops import ops  # type: ignore
+    from .._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_a585153_dirty.abi3.so → _megablocks_6756875_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:1419672a07ed370d7107ca54a6b694f234efa8e696644ee4e96c1bf396aff6af
 size 11869424
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_a585153_dirty
-ops = torch.ops._megablocks_a585153_dirty
+from . import _megablocks_6756875_dirty
+ops = torch.ops._megablocks_6756875_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_a585153_dirty::{op_name}"
+    return f"_megablocks_6756875_dirty::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py
CHANGED
@@ -10,7 +10,8 @@ import torch
 # We import the backend operations from the megablocks package as
 # grouped_gemm is vendored in megablocks in this repository.
 # from ... import _ops as backend
-from megablocks._ops import ops as backend  # type: ignore
+# from megablocks._ops import ops as backend  # type: ignore
+from .._ops import ops as backend  # type: ignore
 
 def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
     assert not (trans_a and trans_b)
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/arguments.py
CHANGED
@@ -9,7 +9,8 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 
-import megablocks.grouped_gemm_util as grouped_gemm
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
 
 # Type annotation for in-place Tensor initialization function.
 InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/common.py
CHANGED
@@ -3,7 +3,7 @@
 
 import torch
 
-from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
 
 
 def dtype(args: Arguments):
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py
CHANGED
@@ -3,8 +3,8 @@
 
 from typing import Union
 
-from megablocks.layers import glu, mlp
-from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
 
 MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
 
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmoe.py
CHANGED
@@ -6,11 +6,14 @@ import stk.ops
 import torch
 from stk import Matrix
 
-import megablocks.ops as ops
-# from megablocks.ops import ops
-from megablocks.layers import common, dmlp_registry, moe, mpu
-from megablocks.layers.arguments import Arguments
-
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
 
 def promote_scalar(x):
     return x.view(1) if not len(x.size()) else x
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/glu.py
CHANGED
@@ -4,11 +4,22 @@
 import stk.ops
 import torch
 
-from megablocks import grouped_gemm_util as gg
-from megablocks.layers import common, mpu
-from megablocks.layers.activation_fn import act_fn
-from megablocks.layers.arguments import Arguments
-from megablocks.layers.mlp import (
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+#     SharedMLP,
+#     SparseMLP,
+#     create_dmoe_expert_weights,
+#     resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
     SharedMLP,
     SparseMLP,
     create_dmoe_expert_weights,
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/memory_test.py
CHANGED
@@ -6,7 +6,8 @@ import gc
 import torch
 import torch.distributed as dist
 
-from megablocks.layers import arguments, dmoe
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
 
 _TESTS = ((8, 2048, 4096, 4096, 32, 4),)
 
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mlp.py
CHANGED
@@ -9,11 +9,15 @@ import stk.ops
 import torch
 from packaging import version
 
-from megablocks import grouped_gemm_util as gg
-from megablocks.layers import common, gelu, mpu
-from megablocks.layers.activation_fn import act_fn
-from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
-
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
 
 class ScaleGradient(torch.autograd.Function):
 
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/moe.py
CHANGED
@@ -6,10 +6,27 @@ import numpy as np
 import torch
 import torch.distributed as dist
 
-import megablocks.ops as ops
-from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
-from megablocks.layers.all_to_all import all_to_all
-from megablocks.layers.arguments import Arguments
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+    sort,
+    histogram,
+    inclusive_cumsum,
+    exclusive_cumsum,
+    binned_gather,
+    binned_scatter,
+    gather,
+    scatter,
+    repeat,
+    replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
 
 _LOAD_BALANCING_LOSS = []
 
@@ -158,7 +175,8 @@ class ParallelMLP(torch.nn.Module):
         # prior? Could we place the `torch.max` operation to return
         # 32-bit expert indices?
         top_expert = top_expert.int()
-        output = ops.sort(top_expert, self.sort_end_bit)
+        # output = ops.sort(top_expert, self.sort_end_bit)
+        output = sort(top_expert, self.sort_end_bit)
         assert output is not None
         bin_ids, indices = output
 
@@ -168,10 +186,12 @@ class ParallelMLP(torch.nn.Module):
         # TODO(tgale): Does the sorted data produce a more favorable
         # data distribution for histogram? Or is the op parallelism
         # worth more?
-        tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+        # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+        tokens_per_expert = histogram(top_expert, self.num_experts)
 
         # Calculate the bin bounds for the sorted tokens.
-        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        bins = inclusive_cumsum(tokens_per_expert, 0)
         assert bins is not None
         bins = bins.view(1) if not len(bins.size()) else bins
 
@@ -195,7 +215,8 @@ class ParallelMLP(torch.nn.Module):
     ):
         # Route the tokens for MoE computation.
         x = x.view(-1, x.shape[-1])
-        output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+        # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+        output = binned_gather(x, indices, bins, expert_capacity, top_k)
         assert output is not None
         x = output
 
@@ -204,7 +225,9 @@ class ParallelMLP(torch.nn.Module):
         x = self.mlp(x)
 
         # Un-route the data for the MoE output.
-        return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+        # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+        return binned_scatter(x, indices, expert_weights, bins, top_k)
+
 
     def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
         # x: [sl, bs, hs]
@@ -264,7 +287,8 @@ class ParallelMLP(torch.nn.Module):
         # If we're sharding the experts along the hidden dimension
         # multiple devices own parts of the same sets of experts.
         # Replicate the token counts so every device gets the counts.
-        repeated_tokens_per_expert = ops.repeat(
+        # repeated_tokens_per_expert = ops.repeat(
+        repeated_tokens_per_expert = repeat(
             tokens_per_expert,
             (mpu.hidden_sharding_degree(self.args),),
         )
@@ -285,7 +309,8 @@ class ParallelMLP(torch.nn.Module):
         # This view updates the shape of the tensor from [sl, bs, hs] to
         # [sl * bs, hs] prior to the permutation.
         x = x.view(-1, x.shape[-1])
-        output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+        # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+        output = gather(x, indices, bin_ids, bins, self.top_k)
         assert output is not None
         x = output
 
@@ -317,7 +342,8 @@ class ParallelMLP(torch.nn.Module):
         # get all of the tokens assigned to them.
         #
        # TODO(tgale): Fuse this into the prior, local permutation.
-        x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+        # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+        x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
 
         # Start the cross-device permutation asynchronously so we can
         # overlap communication with computation.
@@ -336,7 +362,8 @@ class ParallelMLP(torch.nn.Module):
         # for expert computation we'll do one more local permutation. The
         # rest of this torch.no_grad() scope sets up the indices and bins
         # for this permutation.
-        replicate_bins = ops.inclusive_cumsum(
+        # replicate_bins = ops.inclusive_cumsum(
+        replicate_bins = inclusive_cumsum(
             parallel_tokens_per_expert.flatten(),
             0,
         )
@@ -351,14 +378,16 @@ class ParallelMLP(torch.nn.Module):
             ),
             mpu.experts_per_rank(self.args),
         )
-        parallel_top_expert = ops.replicate(
+        # parallel_top_expert = ops.replicate(
+        parallel_top_expert = replicate(
             parallel_top_expert.unsqueeze(dim=0),
             replicate_bins,
             tokens_received,
         ).flatten()
 
         # TODO(tgale): The sort_end_bit here can be reduced.
-        parallel_bin_ids, parallel_indices = ops.sort(
+        # parallel_bin_ids, parallel_indices = ops.sort(
+        parallel_bin_ids, parallel_indices = sort(
             parallel_top_expert,
             self.sort_end_bit,
         )
@@ -368,7 +397,8 @@ class ParallelMLP(torch.nn.Module):
             dim=0,
             dtype=torch.int,
         )
-        parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+        # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+        parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
         parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
 
         # If expert_capacity is set to zero, set the number of tokens
@@ -416,10 +446,12 @@ class ParallelMLP(torch.nn.Module):
             -1,
             self.args.hidden_size,
         )
-        x = ops.sum(x.view(shape), dim=0)
+        # x = ops.sum(x.view(shape), dim=0)
+        x = x.view(shape).sum(dim=0)
 
         # Un-permute locally to setup for the next series of operations.
-        x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+        # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+        x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
         return x, tokens_per_expert.flatten()
 
     def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mpu.py
CHANGED
@@ -6,7 +6,8 @@ from typing import Optional
 import torch
 import torch.distributed as dist
 
-from megablocks.layers.arguments import Arguments
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
 
 
 class MoeParam(torch.Tensor):
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/router.py
CHANGED
@@ -4,8 +4,10 @@ from typing import Any
 
 import torch
 
-from megablocks.layers import common
-from megablocks.layers.arguments import Arguments
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
 
 _ROUTER_LOGITS = []
 
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py
CHANGED
@@ -3,8 +3,10 @@
 
 from typing import Union
 
-from megablocks.layers import glu, mlp
-from megablocks.layers.arguments import Arguments
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
 
 _REGISTRY = {
     'mlp': mlp.SharedMLP,
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/__init__.py
CHANGED
@@ -1,20 +1,20 @@
 # Copyright 2024 Databricks
 # SPDX-License-Identifier: Apache-2.0
 
-from megablocks.ops.binned_gather import binned_gather
-from megablocks.ops.binned_scatter import binned_scatter
-from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum
-from megablocks.ops.gather import gather
-from megablocks.ops.histogram import histogram
-from megablocks.ops.padded_gather import padded_gather
-from megablocks.ops.padded_scatter import padded_scatter
-from megablocks.ops.repeat import repeat
-from megablocks.ops.replicate import replicate
-from megablocks.ops.round_up import round_up
-from megablocks.ops.scatter import scatter
-from megablocks.ops.sort import sort
-from megablocks.ops.sum import sum
-from megablocks.ops.topology import topology
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
 
 __all__ = [
     'binned_gather',
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py
CHANGED
@@ -4,8 +4,11 @@
 import torch
 import torch.distributed as dist
 
-from megablocks import benchmark_util
-from megablocks.layers.all_to_all import all_to_all
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from ..layers.all_to_all import all_to_all
 
 _ALL_TO_ALL_BENCHMARK = (
     (8, 1024),
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Any
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
 
-from megablocks.backend import kernels
+from ..backend import kernels
 
 
 # Autograd wrapper for binned_gather kernel.
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Any
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
 
-from megablocks.backend import kernels
+from ..backend import kernels
 
 
 # Autograd wrapper for binned_scatter kernel.