kernel
drbh committed
Commit b2bfc37 · 1 parent: 484fde0

fix: bump build and imports

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_a585153_dirty.abi3.so → _megablocks_6756875_dirty.abi3.so} +1 -1
  2. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  3. build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py +2 -1
  4. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py +2 -1
  5. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py +1 -1
  6. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py +2 -2
  7. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py +8 -5
  8. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py +16 -5
  9. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py +2 -1
  10. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py +9 -5
  11. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py +50 -18
  12. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py +2 -1
  13. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py +4 -2
  14. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py +4 -2
  15. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py +14 -14
  16. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +5 -2
  17. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py +1 -1
  18. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +1 -1
  19. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py +1 -1
  20. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py +1 -1
  21. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py +1 -1
  22. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py +1 -1
  23. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py +1 -1
  24. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py +1 -1
  25. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +1 -1
  26. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +1 -1
  27. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py +1 -1
  28. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py +1 -1
  29. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py +1 -1
  30. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py +1 -1
  31. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py +1 -1
  32. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py +1 -1
  33. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_a585153_dirty.abi3.so → _megablocks_6756875_dirty.abi3.so} +1 -1
  34. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +3 -3
  35. build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py +2 -1
  36. build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/arguments.py +2 -1
  37. build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/common.py +1 -1
  38. build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py +2 -2
  39. build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmoe.py +8 -5
  40. build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/glu.py +16 -5
  41. build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/memory_test.py +2 -1
  42. build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mlp.py +9 -5
  43. build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/moe.py +50 -18
  44. build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mpu.py +2 -1
  45. build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/router.py +4 -2
  46. build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py +4 -2
  47. build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/__init__.py +14 -14
  48. build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +5 -2
  49. build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py +1 -1
  50. build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py +1 -1
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_a585153_dirty.abi3.so → _megablocks_6756875_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:44462d45f75616c369c2421fe41d53cd1d1dc365f1d2545d870e2db999e67e38
+oid sha256:ad46e9f244afa886c8a104d75e37f93afd2a0ecf83bfc7a414680fa16d8b78f9
 size 10517608
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_a585153_dirty
-ops = torch.ops._megablocks_a585153_dirty
+from . import _megablocks_6756875_dirty
+ops = torch.ops._megablocks_6756875_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_a585153_dirty::{op_name}"
+    return f"_megablocks_6756875_dirty::{op_name}"
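Note: the only change in `_ops.py` is the build hash baked into the native extension's module name. A minimal sketch of how this module is consumed downstream; the op name `exclusive_cumsum` and the commented `torch.library` usage are illustrative assumptions, not part of this diff:

from megablocks._ops import ops, add_op_namespace_prefix

# `ops` is torch.ops._megablocks_6756875_dirty, the dispatcher namespace the
# compiled .so registers its kernels under; bumping the build hash renames that
# namespace, which is why both lines above had to change together.
qualified = add_op_namespace_prefix("exclusive_cumsum")
print(qualified)  # "_megablocks_6756875_dirty::exclusive_cumsum"

# The qualified name is what torch.library utilities expect, e.g. (illustrative):
# torch.library.register_fake(qualified)(lambda *args: ...)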
build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py CHANGED
@@ -10,7 +10,8 @@ import torch
 # We import the backend operations from the megablocks package as
 # grouped_gemm is vendored in megablocks in this repository.
 # from ... import _ops as backend
-from megablocks._ops import ops as backend  # type: ignore
+# from megablocks._ops import ops as backend  # type: ignore
+from .._ops import ops as backend  # type: ignore
 
 def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
     assert not (trans_a and trans_b)
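The comment in the hunk states the motivation: grouped_gemm is vendored inside megablocks, so the backend must be reached relative to whatever top-level name the package is installed under. A rough sketch of why the relative form is more robust; the helper and the second package name are hypothetical, for illustration only:

import importlib

def load_backend(top_level: str):
    # Illustrative helper: import the vendored backend under any top-level name.
    return importlib.import_module(f"{top_level}.grouped_gemm.backend").backend

# load_backend("megablocks") and load_backend("some_hash_suffixed_build") would
# both resolve <top_level>._ops.ops, because backend.py now imports it relative
# to its own package instead of hard-coding the name "megablocks".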
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py CHANGED
@@ -9,7 +9,8 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 
-import megablocks.grouped_gemm_util as grouped_gemm
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
 
 # Type annotation for in-place Tensor initialization function.
 InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py CHANGED
@@ -3,7 +3,7 @@
 
 import torch
 
-from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
 
 
 def dtype(args: Arguments):
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py CHANGED
@@ -3,8 +3,8 @@
 
 from typing import Union
 
-from megablocks.layers import glu, mlp
-from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
 
 MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py CHANGED
@@ -6,11 +6,14 @@ import stk.ops
 import torch
 from stk import Matrix
 
-import megablocks.ops as ops
-# from megablocks.ops import ops
-from megablocks.layers import common, dmlp_registry, moe, mpu
-from megablocks.layers.arguments import Arguments
-
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
 
 def promote_scalar(x):
     return x.view(1) if not len(x.size()) else x
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py CHANGED
@@ -4,11 +4,22 @@
 import stk.ops
 import torch
 
-from megablocks import grouped_gemm_util as gg
-from megablocks.layers import common, mpu
-from megablocks.layers.activation_fn import act_fn
-from megablocks.layers.arguments import Arguments
-from megablocks.layers.mlp import (
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+#     SharedMLP,
+#     SparseMLP,
+#     create_dmoe_expert_weights,
+#     resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
     SharedMLP,
     SparseMLP,
     create_dmoe_expert_weights,
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py CHANGED
@@ -6,7 +6,8 @@ import gc
 import torch
 import torch.distributed as dist
 
-from megablocks.layers import arguments, dmoe
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
 
 _TESTS = ((8, 2048, 4096, 4096, 32, 4),)
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py CHANGED
@@ -9,11 +9,15 @@ import stk.ops
 import torch
 from packaging import version
 
-from megablocks import grouped_gemm_util as gg
-from megablocks.layers import common, gelu, mpu
-from megablocks.layers.activation_fn import act_fn
-from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
-
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
 
 class ScaleGradient(torch.autograd.Function):
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py CHANGED
@@ -6,10 +6,27 @@ import numpy as np
 import torch
 import torch.distributed as dist
 
-import megablocks.ops as ops
-from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
-from megablocks.layers.all_to_all import all_to_all
-from megablocks.layers.arguments import Arguments
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+    sort,
+    histogram,
+    inclusive_cumsum,
+    exclusive_cumsum,
+    binned_gather,
+    binned_scatter,
+    gather,
+    scatter,
+    repeat,
+    replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
 
 _LOAD_BALANCING_LOSS = []
 
@@ -158,7 +175,8 @@ class ParallelMLP(torch.nn.Module):
         # prior? Could we place the `torch.max` operation to return
         # 32-bit expert indices?
         top_expert = top_expert.int()
-        output = ops.sort(top_expert, self.sort_end_bit)
+        # output = ops.sort(top_expert, self.sort_end_bit)
+        output = sort(top_expert, self.sort_end_bit)
         assert output is not None
         bin_ids, indices = output
 
@@ -168,10 +186,12 @@ class ParallelMLP(torch.nn.Module):
         # TODO(tgale): Does the sorted data produce a more favorable
         # data distribution for histogram? Or is the op parallelism
         # worth more?
-        tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+        # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+        tokens_per_expert = histogram(top_expert, self.num_experts)
 
         # Calculate the bin bounds for the sorted tokens.
-        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        bins = inclusive_cumsum(tokens_per_expert, 0)
         assert bins is not None
         bins = bins.view(1) if not len(bins.size()) else bins
 
@@ -195,7 +215,8 @@ class ParallelMLP(torch.nn.Module):
     ):
         # Route the tokens for MoE computation.
         x = x.view(-1, x.shape[-1])
-        output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+        # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+        output = binned_gather(x, indices, bins, expert_capacity, top_k)
         assert output is not None
         x = output
 
@@ -204,7 +225,9 @@ class ParallelMLP(torch.nn.Module):
         x = self.mlp(x)
 
         # Un-route the data for the MoE output.
-        return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+        # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+        return binned_scatter(x, indices, expert_weights, bins, top_k)
+
 
     def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
         # x: [sl, bs, hs]
@@ -264,7 +287,8 @@ class ParallelMLP(torch.nn.Module):
         # If we're sharding the experts along the hidden dimension
         # multiple devices own parts of the same sets of experts.
         # Replicate the token counts so every device gets the counts.
-        repeated_tokens_per_expert = ops.repeat(
+        # repeated_tokens_per_expert = ops.repeat(
+        repeated_tokens_per_expert = repeat(
             tokens_per_expert,
             (mpu.hidden_sharding_degree(self.args),),
         )
@@ -285,7 +309,8 @@ class ParallelMLP(torch.nn.Module):
         # This view updates the shape of the tensor from [sl, bs, hs] to
         # [sl * bs, hs] prior to the permutation.
         x = x.view(-1, x.shape[-1])
-        output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+        # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+        output = gather(x, indices, bin_ids, bins, self.top_k)
         assert output is not None
         x = output
 
@@ -317,7 +342,8 @@ class ParallelMLP(torch.nn.Module):
        # get all of the tokens assigned to them.
        #
        # TODO(tgale): Fuse this into the prior, local permutation.
-        x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+        # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+        x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
 
         # Start the cross-device permutation asynchronously so we can
         # overlap communication with computation.
@@ -336,7 +362,8 @@ class ParallelMLP(torch.nn.Module):
         # for expert computation we'll do one more local permutation. The
         # rest of this torch.no_grad() scope sets up the indices and bins
         # for this permutation.
-        replicate_bins = ops.inclusive_cumsum(
+        # replicate_bins = ops.inclusive_cumsum(
+        replicate_bins = inclusive_cumsum(
             parallel_tokens_per_expert.flatten(),
             0,
         )
@@ -351,14 +378,16 @@ class ParallelMLP(torch.nn.Module):
             ),
             mpu.experts_per_rank(self.args),
         )
-        parallel_top_expert = ops.replicate(
+        # parallel_top_expert = ops.replicate(
+        parallel_top_expert = replicate(
             parallel_top_expert.unsqueeze(dim=0),
             replicate_bins,
             tokens_received,
         ).flatten()
 
         # TODO(tgale): The sort_end_bit here can be reduced.
-        parallel_bin_ids, parallel_indices = ops.sort(
+        # parallel_bin_ids, parallel_indices = ops.sort(
+        parallel_bin_ids, parallel_indices = sort(
             parallel_top_expert,
             self.sort_end_bit,
         )
@@ -368,7 +397,8 @@ class ParallelMLP(torch.nn.Module):
             dim=0,
             dtype=torch.int,
         )
-        parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+        # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+        parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
         parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
 
         # If expert_capacity is set to zero, set the number of tokens
@@ -416,10 +446,12 @@ class ParallelMLP(torch.nn.Module):
             -1,
             self.args.hidden_size,
         )
-        x = ops.sum(x.view(shape), dim=0)
+        # x = ops.sum(x.view(shape), dim=0)
+        x = x.view(shape).sum(dim=0)
 
         # Un-permute locally to setup for the next series of operations.
-        x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+        # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+        x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
         return x, tokens_per_expert.flatten()
 
     def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
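Beyond the import rewrite, one call in moe.py changes form but not behavior: `ops.sum(x.view(shape), dim=0)` becomes the built-in `x.view(shape).sum(dim=0)`. A quick sanity check of that equivalence; the shape values below are made up for illustration, while in moe.py `shape` is `(mpu.hidden_sharding_degree(args), -1, hidden_size)`:

import torch

hidden_sharding_degree, tokens, hidden_size = 2, 8, 16
x = torch.randn(hidden_sharding_degree * tokens, hidden_size)
shape = (hidden_sharding_degree, -1, hidden_size)

summed = x.view(shape).sum(dim=0)              # new code path
manual = x.view(shape)[0] + x.view(shape)[1]   # the same reduction written out
assert summed.shape == (tokens, hidden_size)
assert torch.allclose(summed, manual)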
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py CHANGED
@@ -6,7 +6,8 @@ from typing import Optional
 import torch
 import torch.distributed as dist
 
-from megablocks.layers.arguments import Arguments
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
 
 
 class MoeParam(torch.Tensor):
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py CHANGED
@@ -4,8 +4,10 @@ from typing import Any
 
 import torch
 
-from megablocks.layers import common
-from megablocks.layers.arguments import Arguments
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
 
 _ROUTER_LOGITS = []
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py CHANGED
@@ -3,8 +3,10 @@
 
 from typing import Union
 
-from megablocks.layers import glu, mlp
-from megablocks.layers.arguments import Arguments
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
 
 _REGISTRY = {
     'mlp': mlp.SharedMLP,
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py CHANGED
@@ -1,20 +1,20 @@
 # Copyright 2024 Databricks
 # SPDX-License-Identifier: Apache-2.0
 
-from megablocks.ops.binned_gather import binned_gather
-from megablocks.ops.binned_scatter import binned_scatter
-from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum
-from megablocks.ops.gather import gather
-from megablocks.ops.histogram import histogram
-from megablocks.ops.padded_gather import padded_gather
-from megablocks.ops.padded_scatter import padded_scatter
-from megablocks.ops.repeat import repeat
-from megablocks.ops.replicate import replicate
-from megablocks.ops.round_up import round_up
-from megablocks.ops.scatter import scatter
-from megablocks.ops.sort import sort
-from megablocks.ops.sum import sum
-from megablocks.ops.topology import topology
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
 
 __all__ = [
     'binned_gather',
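The public surface of `megablocks.ops` is unchanged here; only the import style inside the package is. Assuming the package imports cleanly, the two spellings used elsewhere in this commit resolve to the same objects:

from megablocks import ops          # style kept by dmoe.py ("from .. import ops")
from megablocks.ops import sort     # style adopted by moe.py ("from ..ops import sort")

# Same function object either way, so call sites can use whichever reads better.
assert ops.sort is sort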
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -4,8 +4,11 @@
 import torch
 import torch.distributed as dist
 
-from megablocks import benchmark_util
-from megablocks.layers.all_to_all import all_to_all
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from ..layers.all_to_all import all_to_all
 
 _ALL_TO_ALL_BENCHMARK = (
     (8, 1024),
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py CHANGED
@@ -5,7 +5,7 @@ from typing import Any
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
 
-from megablocks.backend import kernels
+from ..backend import kernels
 
 
 # Autograd wrapper for binned_gather kernel.
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py CHANGED
@@ -5,7 +5,7 @@ from typing import Any
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
 
-from megablocks.backend import kernels
+from ..backend import kernels
 
 
 # Autograd wrapper for binned_scatter kernel.
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py CHANGED
@@ -11,7 +11,7 @@ import torch
 # instructions for building the c++ operations.
 try:
     # import megablocks_ops as ops  # type: ignore
-    from megablocks._ops import ops  # type: ignore
+    from .._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py CHANGED
@@ -5,7 +5,7 @@ from typing import Any
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
 
-from megablocks.backend import kernels
+from ..backend import kernels
 
 
 # Autograd wrapper for gather kernel.
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py CHANGED
@@ -10,7 +10,7 @@ import torch
 # Wrap this in a try-block with better error message and
 # instructions for building the c++ operations.
 try:
-    from megablocks._ops import ops  # type: ignore
+    from .._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py CHANGED
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 from absl.testing import parameterized
 
-from megablocks import ops
+from .. import ops
 
 _HISTOGRAM_TESTS = (
     (16384, torch.int32, 2),
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py CHANGED
@@ -7,7 +7,7 @@ import stk
 import torch
 from absl.testing import parameterized
 
-from megablocks import benchmark_util, ops
+from .. import benchmark_util, ops
 
 
 # Calling tensor.t() calls tensor.transpose(0, 1) which calls
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py CHANGED
@@ -5,7 +5,7 @@ from typing import Any
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
 
-from megablocks.backend import kernels
+from ..backend import kernels
 
 
 # Autograd wrapper for padded_gather kernel.
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py CHANGED
@@ -5,7 +5,7 @@ from typing import Any
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
 
-from megablocks.backend import kernels
+from ..backend import kernels
 
 
 # Autograd wrapper for padded_scatter kernel.
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py CHANGED
@@ -6,7 +6,7 @@ import unittest
 import torch
 from absl.testing import parameterized
 
-from megablocks import benchmark_util, ops
+from .. import benchmark_util, ops
 
 _PADDED_SCATTER_BENCHMARK = (
     # dMoE-Medium, 8-way EMP.
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py CHANGED
@@ -6,7 +6,7 @@ import unittest
 import torch
 from absl.testing import parameterized
 
-from megablocks import benchmark_util, ops
+from .. import benchmark_util, ops
 
 _PERMUTE_TESTS = (
     (16384, 768, 2),
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py CHANGED
@@ -10,7 +10,7 @@ import torch
 # Wrap this in a try-block with better error message and
 # instructions for building the c++ operations.
 try:
-    from megablocks._ops import ops  # type: ignore
+    from .._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py CHANGED
@@ -6,7 +6,7 @@ from typing import Any, Optional
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
 
-from megablocks.backend import kernels
+from ..backend import kernels
 
 
 # Autograd wrapper for scatter kernel.
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py CHANGED
@@ -10,7 +10,7 @@ import torch
 # Wrap this in a try-block with better error message and
 # instructions for building the c++ operations.
 try:
-    from megablocks._ops import ops  # type: ignore
+    from .._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py CHANGED
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 from absl.testing import parameterized
 
-from megablocks import ops
+from .. import ops
 
 _SORT_TESTS = (
     (16384, torch.int32, None),
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py CHANGED
@@ -10,7 +10,7 @@ import torch
 # Wrap this in a try-block with better error message and
 # instructions for building the c++ operations.
 try:
-    from megablocks._ops import ops  # type: ignore
+    from .._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_a585153_dirty.abi3.so → _megablocks_6756875_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e734576700345e035790357ea19730e84e90c176747076ce845995bc3a0e0d50
+oid sha256:1419672a07ed370d7107ca54a6b694f234efa8e696644ee4e96c1bf396aff6af
 size 11869424
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_a585153_dirty
-ops = torch.ops._megablocks_a585153_dirty
+from . import _megablocks_6756875_dirty
+ops = torch.ops._megablocks_6756875_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_a585153_dirty::{op_name}"
+    return f"_megablocks_6756875_dirty::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py CHANGED
@@ -10,7 +10,8 @@ import torch
 # We import the backend operations from the megablocks package as
 # grouped_gemm is vendored in megablocks in this repository.
 # from ... import _ops as backend
-from megablocks._ops import ops as backend  # type: ignore
+# from megablocks._ops import ops as backend  # type: ignore
+from .._ops import ops as backend  # type: ignore
 
 def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
     assert not (trans_a and trans_b)
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/arguments.py CHANGED
@@ -9,7 +9,8 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 
-import megablocks.grouped_gemm_util as grouped_gemm
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
 
 # Type annotation for in-place Tensor initialization function.
 InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/common.py CHANGED
@@ -3,7 +3,7 @@
 
 import torch
 
-from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
 
 
 def dtype(args: Arguments):
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py CHANGED
@@ -3,8 +3,8 @@
 
 from typing import Union
 
-from megablocks.layers import glu, mlp
-from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
 
 MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmoe.py CHANGED
@@ -6,11 +6,14 @@ import stk.ops
 import torch
 from stk import Matrix
 
-import megablocks.ops as ops
-# from megablocks.ops import ops
-from megablocks.layers import common, dmlp_registry, moe, mpu
-from megablocks.layers.arguments import Arguments
-
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
 
 def promote_scalar(x):
     return x.view(1) if not len(x.size()) else x
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/glu.py CHANGED
@@ -4,11 +4,22 @@
 import stk.ops
 import torch
 
-from megablocks import grouped_gemm_util as gg
-from megablocks.layers import common, mpu
-from megablocks.layers.activation_fn import act_fn
-from megablocks.layers.arguments import Arguments
-from megablocks.layers.mlp import (
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import Arguments
+# from megablocks.layers.mlp import (
+#     SharedMLP,
+#     SparseMLP,
+#     create_dmoe_expert_weights,
+#     resolve_dtensor,
+# )
+
+from .. import grouped_gemm_util as gg
+from . import common, mpu
+from .activation_fn import act_fn
+from .arguments import Arguments
+from .mlp import (
     SharedMLP,
     SparseMLP,
     create_dmoe_expert_weights,
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/memory_test.py CHANGED
@@ -6,7 +6,8 @@ import gc
 import torch
 import torch.distributed as dist
 
-from megablocks.layers import arguments, dmoe
+# from megablocks.layers import arguments, dmoe
+from . import arguments, dmoe
 
 _TESTS = ((8, 2048, 4096, 4096, 32, 4),)
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mlp.py CHANGED
@@ -9,11 +9,15 @@ import stk.ops
 import torch
 from packaging import version
 
-from megablocks import grouped_gemm_util as gg
-from megablocks.layers import common, gelu, mpu
-from megablocks.layers.activation_fn import act_fn
-from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
-
+# from megablocks import grouped_gemm_util as gg
+# from megablocks.layers import common, gelu, mpu
+# from megablocks.layers.activation_fn import act_fn
+# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
+
+from .. import grouped_gemm_util as gg
+from . import common, gelu, mpu
+from .activation_fn import act_fn
+from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
 
 class ScaleGradient(torch.autograd.Function):
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/moe.py CHANGED
@@ -6,10 +6,27 @@ import numpy as np
 import torch
 import torch.distributed as dist
 
-import megablocks.ops as ops
-from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
-from megablocks.layers.all_to_all import all_to_all
-from megablocks.layers.arguments import Arguments
+# import megablocks.ops as ops
+# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
+# from megablocks.layers.all_to_all import all_to_all
+# from megablocks.layers.arguments import Arguments
+
+from ..ops import (
+    sort,
+    histogram,
+    inclusive_cumsum,
+    exclusive_cumsum,
+    binned_gather,
+    binned_scatter,
+    gather,
+    scatter,
+    repeat,
+    replicate,
+)
+
+from . import common, mlp, mpu, router, sharedexpert_registry
+from .arguments import Arguments
+from .all_to_all import all_to_all
 
 _LOAD_BALANCING_LOSS = []
 
@@ -158,7 +175,8 @@ class ParallelMLP(torch.nn.Module):
         # prior? Could we place the `torch.max` operation to return
         # 32-bit expert indices?
         top_expert = top_expert.int()
-        output = ops.sort(top_expert, self.sort_end_bit)
+        # output = ops.sort(top_expert, self.sort_end_bit)
+        output = sort(top_expert, self.sort_end_bit)
         assert output is not None
         bin_ids, indices = output
 
@@ -168,10 +186,12 @@ class ParallelMLP(torch.nn.Module):
         # TODO(tgale): Does the sorted data produce a more favorable
         # data distribution for histogram? Or is the op parallelism
         # worth more?
-        tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+        # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+        tokens_per_expert = histogram(top_expert, self.num_experts)
 
         # Calculate the bin bounds for the sorted tokens.
-        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        bins = inclusive_cumsum(tokens_per_expert, 0)
         assert bins is not None
         bins = bins.view(1) if not len(bins.size()) else bins
 
@@ -195,7 +215,8 @@ class ParallelMLP(torch.nn.Module):
     ):
         # Route the tokens for MoE computation.
         x = x.view(-1, x.shape[-1])
-        output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+        # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+        output = binned_gather(x, indices, bins, expert_capacity, top_k)
         assert output is not None
         x = output
 
@@ -204,7 +225,9 @@ class ParallelMLP(torch.nn.Module):
         x = self.mlp(x)
 
         # Un-route the data for the MoE output.
-        return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+        # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+        return binned_scatter(x, indices, expert_weights, bins, top_k)
+
 
     def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
         # x: [sl, bs, hs]
@@ -264,7 +287,8 @@ class ParallelMLP(torch.nn.Module):
         # If we're sharding the experts along the hidden dimension
         # multiple devices own parts of the same sets of experts.
         # Replicate the token counts so every device gets the counts.
-        repeated_tokens_per_expert = ops.repeat(
+        # repeated_tokens_per_expert = ops.repeat(
+        repeated_tokens_per_expert = repeat(
             tokens_per_expert,
             (mpu.hidden_sharding_degree(self.args),),
         )
@@ -285,7 +309,8 @@ class ParallelMLP(torch.nn.Module):
         # This view updates the shape of the tensor from [sl, bs, hs] to
         # [sl * bs, hs] prior to the permutation.
         x = x.view(-1, x.shape[-1])
-        output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+        # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+        output = gather(x, indices, bin_ids, bins, self.top_k)
         assert output is not None
         x = output
 
@@ -317,7 +342,8 @@ class ParallelMLP(torch.nn.Module):
        # get all of the tokens assigned to them.
        #
        # TODO(tgale): Fuse this into the prior, local permutation.
-        x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+        # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+        x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
 
         # Start the cross-device permutation asynchronously so we can
         # overlap communication with computation.
@@ -336,7 +362,8 @@ class ParallelMLP(torch.nn.Module):
         # for expert computation we'll do one more local permutation. The
         # rest of this torch.no_grad() scope sets up the indices and bins
         # for this permutation.
-        replicate_bins = ops.inclusive_cumsum(
+        # replicate_bins = ops.inclusive_cumsum(
+        replicate_bins = inclusive_cumsum(
             parallel_tokens_per_expert.flatten(),
             0,
         )
@@ -351,14 +378,16 @@ class ParallelMLP(torch.nn.Module):
             ),
             mpu.experts_per_rank(self.args),
         )
-        parallel_top_expert = ops.replicate(
+        # parallel_top_expert = ops.replicate(
+        parallel_top_expert = replicate(
             parallel_top_expert.unsqueeze(dim=0),
             replicate_bins,
             tokens_received,
         ).flatten()
 
         # TODO(tgale): The sort_end_bit here can be reduced.
-        parallel_bin_ids, parallel_indices = ops.sort(
+        # parallel_bin_ids, parallel_indices = ops.sort(
+        parallel_bin_ids, parallel_indices = sort(
             parallel_top_expert,
             self.sort_end_bit,
         )
@@ -368,7 +397,8 @@ class ParallelMLP(torch.nn.Module):
             dim=0,
             dtype=torch.int,
         )
-        parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+        # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+        parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
         parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
 
         # If expert_capacity is set to zero, set the number of tokens
@@ -416,10 +446,12 @@ class ParallelMLP(torch.nn.Module):
             -1,
             self.args.hidden_size,
         )
-        x = ops.sum(x.view(shape), dim=0)
+        # x = ops.sum(x.view(shape), dim=0)
+        x = x.view(shape).sum(dim=0)
 
         # Un-permute locally to setup for the next series of operations.
-        x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+        # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+        x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
         return x, tokens_per_expert.flatten()
 
     def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mpu.py CHANGED
@@ -6,7 +6,8 @@ from typing import Optional
 import torch
 import torch.distributed as dist
 
-from megablocks.layers.arguments import Arguments
+# from megablocks.layers.arguments import Arguments
+from .arguments import Arguments
 
 
 class MoeParam(torch.Tensor):
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/router.py CHANGED
@@ -4,8 +4,10 @@ from typing import Any
 
 import torch
 
-from megablocks.layers import common
-from megablocks.layers.arguments import Arguments
+# from megablocks.layers import common
+# from megablocks.layers.arguments import Arguments
+from . import common
+from .arguments import Arguments
 
 _ROUTER_LOGITS = []
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py CHANGED
@@ -3,8 +3,10 @@
 
 from typing import Union
 
-from megablocks.layers import glu, mlp
-from megablocks.layers.arguments import Arguments
+# from megablocks.layers import glu, mlp
+# from megablocks.layers.arguments import Arguments
+from . import glu, mlp
+from .arguments import Arguments
 
 _REGISTRY = {
     'mlp': mlp.SharedMLP,
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/__init__.py CHANGED
@@ -1,20 +1,20 @@
 # Copyright 2024 Databricks
 # SPDX-License-Identifier: Apache-2.0
 
-from megablocks.ops.binned_gather import binned_gather
-from megablocks.ops.binned_scatter import binned_scatter
-from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum
-from megablocks.ops.gather import gather
-from megablocks.ops.histogram import histogram
-from megablocks.ops.padded_gather import padded_gather
-from megablocks.ops.padded_scatter import padded_scatter
-from megablocks.ops.repeat import repeat
-from megablocks.ops.replicate import replicate
-from megablocks.ops.round_up import round_up
-from megablocks.ops.scatter import scatter
-from megablocks.ops.sort import sort
-from megablocks.ops.sum import sum
-from megablocks.ops.topology import topology
+from .binned_gather import binned_gather
+from .binned_scatter import binned_scatter
+from .cumsum import exclusive_cumsum, inclusive_cumsum
+from .gather import gather
+from .histogram import histogram
+from .padded_gather import padded_gather
+from .padded_scatter import padded_scatter
+from .repeat import repeat
+from .replicate import replicate
+from .round_up import round_up
+from .scatter import scatter
+from .sort import sort
+from .sum import sum
+from .topology import topology
 
 __all__ = [
     'binned_gather',
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -4,8 +4,11 @@
 import torch
 import torch.distributed as dist
 
-from megablocks import benchmark_util
-from megablocks.layers.all_to_all import all_to_all
+# from megablocks import benchmark_util
+# from megablocks.layers.all_to_all import all_to_all
+
+from .. import benchmark_util
+from ..layers.all_to_all import all_to_all
 
 _ALL_TO_ALL_BENCHMARK = (
     (8, 1024),
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py CHANGED
@@ -5,7 +5,7 @@ from typing import Any
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
 
-from megablocks.backend import kernels
+from ..backend import kernels
 
 
 # Autograd wrapper for binned_gather kernel.
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py CHANGED
@@ -5,7 +5,7 @@ from typing import Any
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
 
-from megablocks.backend import kernels
+from ..backend import kernels
 
 
 # Autograd wrapper for binned_scatter kernel.