diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..7e041eea505b2cea718255f91475aaf5ed5262b5 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ad46e9f244afa886c8a104d75e37f93afd2a0ecf83bfc7a414680fa16d8b78f9 +size 10517608 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so deleted file mode 100755 index 81c1d31b49a30085121ea47c012801b4f6fe3a00..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:44462d45f75616c369c2421fe41d53cd1d1dc365f1d2545d870e2db999e67e38 -size 10517608 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py index fb18e304d3a97b1801f4f28ccd9b74403f400f6f..c9a9faa3cee07c8b7016a71704ea26e47c460d57 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_a585153_dirty -ops = torch.ops._megablocks_a585153_dirty +from . import _megablocks_6756875_dirty +ops = torch.ops._megablocks_6756875_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py index 7c692f95852950a7b97aaaa2aa3d15e4297b6192..76037d8039cbfc2f0577275c78e4bc0be762592a 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -10,7 +10,8 @@ import torch # We import the backend operations from the megablocks package as # grouped_gemm is vendored in megablocks in this repository. # from ... import _ops as backend -from megablocks._ops import ops as backend # type: ignore +# from megablocks._ops import ops as backend # type: ignore +from .._ops import ops as backend # type: ignore def _allocate_output(a, b, batch_sizes, trans_a, trans_b): assert not (trans_a and trans_b) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py index 3962c771c90012535aab058443f89c541a2e9236..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py @@ -9,7 +9,8 @@ import torch import torch.distributed as dist import torch.nn.functional as F -import megablocks.grouped_gemm_util as grouped_gemm +# import megablocks.grouped_gemm_util as grouped_gemm +from .. import grouped_gemm_util as grouped_gemm # Type annotation for in-place Tensor initialization function. 
InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py index ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927..2d07109702963ba48a3b94ab860807954dfd79c1 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py @@ -3,7 +3,7 @@ import torch -from megablocks.layers.arguments import Arguments +from .arguments import Arguments def dtype(args: Arguments): diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py index d765bd04387a29bddd2789fae04821635b555e82..de2ed047042e438c7190ebb139b6f7f30009734c 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -3,8 +3,8 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +from . import glu, mlp +from .arguments import Arguments MlpType = Union[mlp.SparseMLP, glu.SparseGLU] diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py index 205727ff4d63f9e8dc9648acaac99a97f3394d6f..561fd7982254513e9a9b4f0514edd326b0561256 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py @@ -6,11 +6,14 @@ import stk.ops import torch from stk import Matrix -import megablocks.ops as ops -# from megablocks.ops import ops -from megablocks.layers import common, dmlp_registry, moe, mpu -from megablocks.layers.arguments import Arguments - +# import megablocks.ops as ops +# # from megablocks.ops import ops +# from megablocks.layers import common, dmlp_registry, moe, mpu +# from megablocks.layers.arguments import Arguments + +from .. import ops +from . import common, dmlp_registry, moe, mpu +from .arguments import Arguments def promote_scalar(x): return x.view(1) if not len(x.size()) else x diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py index cbe0c915c307e7f7cade3ea3ff679399635fcd81..69ef9ead5461a5b91ace107652fd8d1c3d300b1e 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py @@ -4,11 +4,22 @@ import stk.ops import torch -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import Arguments -from megablocks.layers.mlp import ( +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import Arguments +# from megablocks.layers.mlp import ( +# SharedMLP, +# SparseMLP, +# create_dmoe_expert_weights, +# resolve_dtensor, +# ) + +from .. import grouped_gemm_util as gg +from . 
import common, mpu +from .activation_fn import act_fn +from .arguments import Arguments +from .mlp import ( SharedMLP, SparseMLP, create_dmoe_expert_weights, diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py index 4acbd94f212ce906ace9de78dd3ffc6afa03f97e..74d1166931b712635131985b25a89f4ca23e576d 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py @@ -6,7 +6,8 @@ import gc import torch import torch.distributed as dist -from megablocks.layers import arguments, dmoe +# from megablocks.layers import arguments, dmoe +from . import arguments, dmoe _TESTS = ((8, 2048, 4096, 4096, 32, 4),) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py index 6e6f4d82441e3a9c185db5bfdf686d53790dde26..0b6747b1eec2047eb9480592ca550b6c9e868f5e 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py @@ -9,11 +9,15 @@ import stk.ops import torch from packaging import version -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, gelu, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, gelu, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + +from .. import grouped_gemm_util as gg +from . import common, gelu, mpu +from .activation_fn import act_fn +from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn class ScaleGradient(torch.autograd.Function): diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py index 9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e..d0a4aeaacc9c86fc70944e730c53f7a55644e05e 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py @@ -6,10 +6,27 @@ import numpy as np import torch import torch.distributed as dist -import megablocks.ops as ops -from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -from megablocks.layers.all_to_all import all_to_all -from megablocks.layers.arguments import Arguments +# import megablocks.ops as ops +# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +# from megablocks.layers.all_to_all import all_to_all +# from megablocks.layers.arguments import Arguments + +from ..ops import ( + sort, + histogram, + inclusive_cumsum, + exclusive_cumsum, + binned_gather, + binned_scatter, + gather, + scatter, + repeat, + replicate, +) + +from . import common, mlp, mpu, router, sharedexpert_registry +from .arguments import Arguments +from .all_to_all import all_to_all _LOAD_BALANCING_LOSS = [] @@ -158,7 +175,8 @@ class ParallelMLP(torch.nn.Module): # prior? Could we place the `torch.max` operation to return # 32-bit expert indices? 
top_expert = top_expert.int() - output = ops.sort(top_expert, self.sort_end_bit) + # output = ops.sort(top_expert, self.sort_end_bit) + output = sort(top_expert, self.sort_end_bit) assert output is not None bin_ids, indices = output @@ -168,10 +186,12 @@ class ParallelMLP(torch.nn.Module): # TODO(tgale): Does the sorted data produce a more favorable # data distribution for histogram? Or is the op parallelism # worth more? - tokens_per_expert = ops.histogram(top_expert, self.num_experts) + # tokens_per_expert = ops.histogram(top_expert, self.num_experts) + tokens_per_expert = histogram(top_expert, self.num_experts) # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) + # bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = inclusive_cumsum(tokens_per_expert, 0) assert bins is not None bins = bins.view(1) if not len(bins.size()) else bins @@ -195,7 +215,8 @@ class ParallelMLP(torch.nn.Module): ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + output = binned_gather(x, indices, bins, expert_capacity, top_k) assert output is not None x = output @@ -204,7 +225,9 @@ class ParallelMLP(torch.nn.Module): x = self.mlp(x) # Un-route the data for the MoE output. - return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return binned_scatter(x, indices, expert_weights, bins, top_k) + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # x: [sl, bs, hs] @@ -264,7 +287,8 @@ class ParallelMLP(torch.nn.Module): # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. # Replicate the token counts so every device gets the counts. - repeated_tokens_per_expert = ops.repeat( + # repeated_tokens_per_expert = ops.repeat( + repeated_tokens_per_expert = repeat( tokens_per_expert, (mpu.hidden_sharding_degree(self.args),), ) @@ -285,7 +309,8 @@ class ParallelMLP(torch.nn.Module): # This view updates the shape of the tensor from [sl, bs, hs] to # [sl * bs, hs] prior to the permutation. x = x.view(-1, x.shape[-1]) - output = ops.gather(x, indices, bin_ids, bins, self.top_k) + # output = ops.gather(x, indices, bin_ids, bins, self.top_k) + output = gather(x, indices, bin_ids, bins, self.top_k) assert output is not None x = output @@ -317,7 +342,8 @@ class ParallelMLP(torch.nn.Module): # get all of the tokens assigned to them. # # TODO(tgale): Fuse this into the prior, local permutation. - x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) # Start the cross-device permutation asynchronously so we can # overlap communication with computation. @@ -336,7 +362,8 @@ class ParallelMLP(torch.nn.Module): # for expert computation we'll do one more local permutation. The # rest of this torch.no_grad() scope sets up the indices and bins # for this permutation. 
- replicate_bins = ops.inclusive_cumsum( + # replicate_bins = ops.inclusive_cumsum( + replicate_bins = inclusive_cumsum( parallel_tokens_per_expert.flatten(), 0, ) @@ -351,14 +378,16 @@ class ParallelMLP(torch.nn.Module): ), mpu.experts_per_rank(self.args), ) - parallel_top_expert = ops.replicate( + # parallel_top_expert = ops.replicate( + parallel_top_expert = replicate( parallel_top_expert.unsqueeze(dim=0), replicate_bins, tokens_received, ).flatten() # TODO(tgale): The sort_end_bit here can be reduced. - parallel_bin_ids, parallel_indices = ops.sort( + # parallel_bin_ids, parallel_indices = ops.sort( + parallel_bin_ids, parallel_indices = sort( parallel_top_expert, self.sort_end_bit, ) @@ -368,7 +397,8 @@ class ParallelMLP(torch.nn.Module): dim=0, dtype=torch.int, ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) # If expert_capacity is set to zero, set the number of tokens @@ -416,10 +446,12 @@ class ParallelMLP(torch.nn.Module): -1, self.args.hidden_size, ) - x = ops.sum(x.view(shape), dim=0) + # x = ops.sum(x.view(shape), dim=0) + x = x.view(shape).sum(dim=0) # Un-permute locally to setup for the next series of operations. - x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py index b23213902f4567d7fdb0158cbcf5406c2b2aa601..434e143ab42bf3f83406d69e9dd1f72777716e22 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py @@ -6,7 +6,8 @@ from typing import Optional import torch import torch.distributed as dist -from megablocks.layers.arguments import Arguments +# from megablocks.layers.arguments import Arguments +from .arguments import Arguments class MoeParam(torch.Tensor): diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py index 2c9dcd9e433322482a80d5b95afccee5c12368f8..37cb2782348d62583376f1a183c7ede83601216d 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py @@ -4,8 +4,10 @@ from typing import Any import torch -from megablocks.layers import common -from megablocks.layers.arguments import Arguments +# from megablocks.layers import common +# from megablocks.layers.arguments import Arguments +from . 
import common +from .arguments import Arguments _ROUTER_LOGITS = [] diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py index 0f62db39a2dda5642f6baa9d9b949f7c31cf6d35..5840862f88f370ace5fd49bd0612fc98d186cc49 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -3,8 +3,10 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +# from megablocks.layers import glu, mlp +# from megablocks.layers.arguments import Arguments +from . import glu, mlp +from .arguments import Arguments _REGISTRY = { 'mlp': mlp.SharedMLP, diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py index b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c..b944080df810d0b0cfc571f3009b0098a651f9b7 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py @@ -1,20 +1,20 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from megablocks.ops.binned_gather import binned_gather -from megablocks.ops.binned_scatter import binned_scatter -from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum -from megablocks.ops.gather import gather -from megablocks.ops.histogram import histogram -from megablocks.ops.padded_gather import padded_gather -from megablocks.ops.padded_scatter import padded_scatter -from megablocks.ops.repeat import repeat -from megablocks.ops.replicate import replicate -from megablocks.ops.round_up import round_up -from megablocks.ops.scatter import scatter -from megablocks.ops.sort import sort -from megablocks.ops.sum import sum -from megablocks.ops.topology import topology +from .binned_gather import binned_gather +from .binned_scatter import binned_scatter +from .cumsum import exclusive_cumsum, inclusive_cumsum +from .gather import gather +from .histogram import histogram +from .padded_gather import padded_gather +from .padded_scatter import padded_scatter +from .repeat import repeat +from .replicate import replicate +from .round_up import round_up +from .scatter import scatter +from .sort import sort +from .sum import sum +from .topology import topology __all__ = [ 'binned_gather', diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 47b95301dbb35b0c3154d1ff2878d20792ce5cb2..43d267dbe2570e4a1f59fd561398668cc2bc0920 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -4,8 +4,11 @@ import torch import torch.distributed as dist -from megablocks import benchmark_util -from megablocks.layers.all_to_all import all_to_all +# from megablocks import benchmark_util +# from megablocks.layers.all_to_all import all_to_all + +from .. 
import benchmark_util +from ..layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py index 89cce1b627137d7f987baca62fd6d82c5c04659a..8ae2ad8388b06db46a13f8fa46083619c44eefe2 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_gather kernel. diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py index f5ce0d6f5a890a58b89e6a501216a9822341323f..6d8654bdc718cd893f7899e9cdb6cd20544d189c 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_scatter kernel. diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py index 0d7c5137c900d66a959a76fa681436362a4df906..e2b7572391e20045d335cf7337246e8a9b9f57ef 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py @@ -11,7 +11,7 @@ import torch # instructions for building the c++ operations. try: # import megablocks_ops as ops # type: ignore - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py index 41b09a1233e8a996ff1062579cd0810d095ad1e6..4edf4541dac52abab94151efb414acaa7711f8f6 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for gather kernel. diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py index 77e06a29163aa335dc61e8c50932d5f83322ae1d..7b3f058ec373cbba7555704fb5e4212c3cc75d9d 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. 
try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py index 9de8e65281c829ba63b5c60853dfa5bb0a333988..c57b7bf8228e01237236748147368b09ffdf8072 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. import ops _HISTOGRAM_TESTS = ( (16384, torch.int32, 2), diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py index bfa7b7c2f697f9454c5381e31ee78040be5b229e..9b2c7f062071758def741bf3af2392d3a4855f89 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -7,7 +7,7 @@ import stk import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops # Calling tensor.t() calls tensor.transpose(0, 1) which calls diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py index f272a7768dc6468e81bd2cd25294ca16a6826c08..0ffe1369d6adbce5c4a54155c636d1a4c022a41d 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_gather kernel. diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py index 9ff81dd9456a04258346266e9e85871bda56c65b..6685b0b83e7b1d92e7a772715991ead8ad94b153 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_scatter kernel. diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py index 81dde4e4b4c0f550c6027390e8de285fa4d842f7..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops _PADDED_SCATTER_BENCHMARK = ( # dMoE-Medium, 8-way EMP. 
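Note on the layers/moe.py hunks above: every former ops.<name>(...) call now invokes the same function imported directly from ..ops, so the routing logic is unchanged; the one substantive swap is ops.sum(x.view(shape), dim=0) becoming the native x.view(shape).sum(dim=0). A minimal sketch of that equivalence, assuming the custom op was a plain sum over the leading hidden-sharding dimension (the shapes below are illustrative, not taken from the patch):

    import torch

    # After the patch, the reduction over the hidden-sharding copies uses the
    # built-in Tensor.sum instead of the vendored ops.sum wrapper.
    hidden_sharding_degree, tokens, hidden_size = 2, 8, 4
    x = torch.randn(hidden_sharding_degree * tokens, hidden_size)
    shape = (hidden_sharding_degree, -1, hidden_size)

    summed = x.view(shape).sum(dim=0)
    assert summed.shape == (tokens, hidden_size)
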
diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py index 837f07e2858d5c65c14e3f262617e7d2985ce901..6536eeeae402659a087e5c51ef9840627af56501 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops _PERMUTE_TESTS = ( (16384, 768, 2), diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py index e570e22093339d1b71dc830edf656897c67f40e7..26daf0eede330603a4b8ea7167faf1411d07ca93 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py index a5aaafc402a651a47595c25dbf71e382903e5022..3f4bacf3422543e11abb795c279876147f8610a8 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py @@ -6,7 +6,7 @@ from typing import Any, Optional import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for scatter kernel. diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py index a64a4553323d8de71a1d5f1821bac50249c68197..bda3bf64283e39533c2eae3627e76bb2d0262c9f 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py index f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec..a92ff957d4c552c6e61d9279a7989795472af7b7 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. 
import ops _SORT_TESTS = ( (16384, torch.int32, None), diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py index b4a693a82a36c052e51afb4a02077e7849a4abf1..76a50d3164db20534b099dcb4d8487a7aef25d15 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..8c2fc7cc25015ef81d73aa2c0231d1af2d6b9c86 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1419672a07ed370d7107ca54a6b694f234efa8e696644ee4e96c1bf396aff6af +size 11869424 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so deleted file mode 100755 index d1d1046bc8625e4dd1bd60290bf9b8f659f3247b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e734576700345e035790357ea19730e84e90c176747076ce845995bc3a0e0d50 -size 11869424 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py index fb18e304d3a97b1801f4f28ccd9b74403f400f6f..c9a9faa3cee07c8b7016a71704ea26e47c460d57 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_a585153_dirty -ops = torch.ops._megablocks_a585153_dirty +from . import _megablocks_6756875_dirty +ops = torch.ops._megablocks_6756875_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py index 7c692f95852950a7b97aaaa2aa3d15e4297b6192..76037d8039cbfc2f0577275c78e4bc0be762592a 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -10,7 +10,8 @@ import torch # We import the backend operations from the megablocks package as # grouped_gemm is vendored in megablocks in this repository. # from ... 
import _ops as backend -from megablocks._ops import ops as backend # type: ignore +# from megablocks._ops import ops as backend # type: ignore +from .._ops import ops as backend # type: ignore def _allocate_output(a, b, batch_sizes, trans_a, trans_b): assert not (trans_a and trans_b) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/arguments.py index 3962c771c90012535aab058443f89c541a2e9236..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/arguments.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/arguments.py @@ -9,7 +9,8 @@ import torch import torch.distributed as dist import torch.nn.functional as F -import megablocks.grouped_gemm_util as grouped_gemm +# import megablocks.grouped_gemm_util as grouped_gemm +from .. import grouped_gemm_util as grouped_gemm # Type annotation for in-place Tensor initialization function. InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/common.py index ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927..2d07109702963ba48a3b94ab860807954dfd79c1 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/common.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/common.py @@ -3,7 +3,7 @@ import torch -from megablocks.layers.arguments import Arguments +from .arguments import Arguments def dtype(args: Arguments): diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py index d765bd04387a29bddd2789fae04821635b555e82..de2ed047042e438c7190ebb139b6f7f30009734c 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -3,8 +3,8 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +from . import glu, mlp +from .arguments import Arguments MlpType = Union[mlp.SparseMLP, glu.SparseGLU] diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmoe.py index 205727ff4d63f9e8dc9648acaac99a97f3394d6f..561fd7982254513e9a9b4f0514edd326b0561256 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmoe.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmoe.py @@ -6,11 +6,14 @@ import stk.ops import torch from stk import Matrix -import megablocks.ops as ops -# from megablocks.ops import ops -from megablocks.layers import common, dmlp_registry, moe, mpu -from megablocks.layers.arguments import Arguments - +# import megablocks.ops as ops +# # from megablocks.ops import ops +# from megablocks.layers import common, dmlp_registry, moe, mpu +# from megablocks.layers.arguments import Arguments + +from .. import ops +from . 
import common, dmlp_registry, moe, mpu +from .arguments import Arguments def promote_scalar(x): return x.view(1) if not len(x.size()) else x diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/glu.py index cbe0c915c307e7f7cade3ea3ff679399635fcd81..69ef9ead5461a5b91ace107652fd8d1c3d300b1e 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/glu.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/glu.py @@ -4,11 +4,22 @@ import stk.ops import torch -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import Arguments -from megablocks.layers.mlp import ( +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import Arguments +# from megablocks.layers.mlp import ( +# SharedMLP, +# SparseMLP, +# create_dmoe_expert_weights, +# resolve_dtensor, +# ) + +from .. import grouped_gemm_util as gg +from . import common, mpu +from .activation_fn import act_fn +from .arguments import Arguments +from .mlp import ( SharedMLP, SparseMLP, create_dmoe_expert_weights, diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/memory_test.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/memory_test.py index 4acbd94f212ce906ace9de78dd3ffc6afa03f97e..74d1166931b712635131985b25a89f4ca23e576d 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/memory_test.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/memory_test.py @@ -6,7 +6,8 @@ import gc import torch import torch.distributed as dist -from megablocks.layers import arguments, dmoe +# from megablocks.layers import arguments, dmoe +from . import arguments, dmoe _TESTS = ((8, 2048, 4096, 4096, 32, 4),) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mlp.py index 6e6f4d82441e3a9c185db5bfdf686d53790dde26..0b6747b1eec2047eb9480592ca550b6c9e868f5e 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mlp.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mlp.py @@ -9,11 +9,15 @@ import stk.ops import torch from packaging import version -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, gelu, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, gelu, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + +from .. import grouped_gemm_util as gg +from . 
import common, gelu, mpu +from .activation_fn import act_fn +from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn class ScaleGradient(torch.autograd.Function): diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/moe.py index 9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e..d0a4aeaacc9c86fc70944e730c53f7a55644e05e 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/moe.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/moe.py @@ -6,10 +6,27 @@ import numpy as np import torch import torch.distributed as dist -import megablocks.ops as ops -from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -from megablocks.layers.all_to_all import all_to_all -from megablocks.layers.arguments import Arguments +# import megablocks.ops as ops +# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +# from megablocks.layers.all_to_all import all_to_all +# from megablocks.layers.arguments import Arguments + +from ..ops import ( + sort, + histogram, + inclusive_cumsum, + exclusive_cumsum, + binned_gather, + binned_scatter, + gather, + scatter, + repeat, + replicate, +) + +from . import common, mlp, mpu, router, sharedexpert_registry +from .arguments import Arguments +from .all_to_all import all_to_all _LOAD_BALANCING_LOSS = [] @@ -158,7 +175,8 @@ class ParallelMLP(torch.nn.Module): # prior? Could we place the `torch.max` operation to return # 32-bit expert indices? top_expert = top_expert.int() - output = ops.sort(top_expert, self.sort_end_bit) + # output = ops.sort(top_expert, self.sort_end_bit) + output = sort(top_expert, self.sort_end_bit) assert output is not None bin_ids, indices = output @@ -168,10 +186,12 @@ class ParallelMLP(torch.nn.Module): # TODO(tgale): Does the sorted data produce a more favorable # data distribution for histogram? Or is the op parallelism # worth more? - tokens_per_expert = ops.histogram(top_expert, self.num_experts) + # tokens_per_expert = ops.histogram(top_expert, self.num_experts) + tokens_per_expert = histogram(top_expert, self.num_experts) # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) + # bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = inclusive_cumsum(tokens_per_expert, 0) assert bins is not None bins = bins.view(1) if not len(bins.size()) else bins @@ -195,7 +215,8 @@ class ParallelMLP(torch.nn.Module): ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + output = binned_gather(x, indices, bins, expert_capacity, top_k) assert output is not None x = output @@ -204,7 +225,9 @@ class ParallelMLP(torch.nn.Module): x = self.mlp(x) # Un-route the data for the MoE output. - return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return binned_scatter(x, indices, expert_weights, bins, top_k) + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # x: [sl, bs, hs] @@ -264,7 +287,8 @@ class ParallelMLP(torch.nn.Module): # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. # Replicate the token counts so every device gets the counts. 
- repeated_tokens_per_expert = ops.repeat( + # repeated_tokens_per_expert = ops.repeat( + repeated_tokens_per_expert = repeat( tokens_per_expert, (mpu.hidden_sharding_degree(self.args),), ) @@ -285,7 +309,8 @@ class ParallelMLP(torch.nn.Module): # This view updates the shape of the tensor from [sl, bs, hs] to # [sl * bs, hs] prior to the permutation. x = x.view(-1, x.shape[-1]) - output = ops.gather(x, indices, bin_ids, bins, self.top_k) + # output = ops.gather(x, indices, bin_ids, bins, self.top_k) + output = gather(x, indices, bin_ids, bins, self.top_k) assert output is not None x = output @@ -317,7 +342,8 @@ class ParallelMLP(torch.nn.Module): # get all of the tokens assigned to them. # # TODO(tgale): Fuse this into the prior, local permutation. - x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) # Start the cross-device permutation asynchronously so we can # overlap communication with computation. @@ -336,7 +362,8 @@ class ParallelMLP(torch.nn.Module): # for expert computation we'll do one more local permutation. The # rest of this torch.no_grad() scope sets up the indices and bins # for this permutation. - replicate_bins = ops.inclusive_cumsum( + # replicate_bins = ops.inclusive_cumsum( + replicate_bins = inclusive_cumsum( parallel_tokens_per_expert.flatten(), 0, ) @@ -351,14 +378,16 @@ class ParallelMLP(torch.nn.Module): ), mpu.experts_per_rank(self.args), ) - parallel_top_expert = ops.replicate( + # parallel_top_expert = ops.replicate( + parallel_top_expert = replicate( parallel_top_expert.unsqueeze(dim=0), replicate_bins, tokens_received, ).flatten() # TODO(tgale): The sort_end_bit here can be reduced. - parallel_bin_ids, parallel_indices = ops.sort( + # parallel_bin_ids, parallel_indices = ops.sort( + parallel_bin_ids, parallel_indices = sort( parallel_top_expert, self.sort_end_bit, ) @@ -368,7 +397,8 @@ class ParallelMLP(torch.nn.Module): dim=0, dtype=torch.int, ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) # If expert_capacity is set to zero, set the number of tokens @@ -416,10 +446,12 @@ class ParallelMLP(torch.nn.Module): -1, self.args.hidden_size, ) - x = ops.sum(x.view(shape), dim=0) + # x = ops.sum(x.view(shape), dim=0) + x = x.view(shape).sum(dim=0) # Un-permute locally to setup for the next series of operations. 
- x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mpu.py index b23213902f4567d7fdb0158cbcf5406c2b2aa601..434e143ab42bf3f83406d69e9dd1f72777716e22 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mpu.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mpu.py @@ -6,7 +6,8 @@ from typing import Optional import torch import torch.distributed as dist -from megablocks.layers.arguments import Arguments +# from megablocks.layers.arguments import Arguments +from .arguments import Arguments class MoeParam(torch.Tensor): diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/router.py index 2c9dcd9e433322482a80d5b95afccee5c12368f8..37cb2782348d62583376f1a183c7ede83601216d 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/router.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/router.py @@ -4,8 +4,10 @@ from typing import Any import torch -from megablocks.layers import common -from megablocks.layers.arguments import Arguments +# from megablocks.layers import common +# from megablocks.layers.arguments import Arguments +from . import common +from .arguments import Arguments _ROUTER_LOGITS = [] diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py index 0f62db39a2dda5642f6baa9d9b949f7c31cf6d35..5840862f88f370ace5fd49bd0612fc98d186cc49 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -3,8 +3,10 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +# from megablocks.layers import glu, mlp +# from megablocks.layers.arguments import Arguments +from . 
import glu, mlp +from .arguments import Arguments _REGISTRY = { 'mlp': mlp.SharedMLP, diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/__init__.py index b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c..b944080df810d0b0cfc571f3009b0098a651f9b7 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/__init__.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/__init__.py @@ -1,20 +1,20 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from megablocks.ops.binned_gather import binned_gather -from megablocks.ops.binned_scatter import binned_scatter -from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum -from megablocks.ops.gather import gather -from megablocks.ops.histogram import histogram -from megablocks.ops.padded_gather import padded_gather -from megablocks.ops.padded_scatter import padded_scatter -from megablocks.ops.repeat import repeat -from megablocks.ops.replicate import replicate -from megablocks.ops.round_up import round_up -from megablocks.ops.scatter import scatter -from megablocks.ops.sort import sort -from megablocks.ops.sum import sum -from megablocks.ops.topology import topology +from .binned_gather import binned_gather +from .binned_scatter import binned_scatter +from .cumsum import exclusive_cumsum, inclusive_cumsum +from .gather import gather +from .histogram import histogram +from .padded_gather import padded_gather +from .padded_scatter import padded_scatter +from .repeat import repeat +from .replicate import replicate +from .round_up import round_up +from .scatter import scatter +from .sort import sort +from .sum import sum +from .topology import topology __all__ = [ 'binned_gather', diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 47b95301dbb35b0c3154d1ff2878d20792ce5cb2..43d267dbe2570e4a1f59fd561398668cc2bc0920 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -4,8 +4,11 @@ import torch import torch.distributed as dist -from megablocks import benchmark_util -from megablocks.layers.all_to_all import all_to_all +# from megablocks import benchmark_util +# from megablocks.layers.all_to_all import all_to_all + +from .. import benchmark_util +from ..layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py index 89cce1b627137d7f987baca62fd6d82c5c04659a..8ae2ad8388b06db46a13f8fa46083619c44eefe2 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_gather kernel. 
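Note on the _ops.py and .abi3.so changes repeated in each build variant: the compiled extension is renamed from _megablocks_a585153_dirty to _megablocks_6756875_dirty, and importing that module is what exposes its operators under the torch.ops namespace of the same name, so the per-build shim must track the new file name. For readability, this is the post-patch _ops.py reconstructed with conventional indentation; it is a package fragment, so the relative import only resolves when the file is loaded as megablocks._ops:

    import torch

    # Importing the hash-suffixed extension module registers its operators,
    # which then appear under the matching torch.ops namespace.
    from . import _megablocks_6756875_dirty

    ops = torch.ops._megablocks_6756875_dirty


    def add_op_namespace_prefix(op_name: str):
        """
        Prefix op by namespace.
        """
        return f"_megablocks_6756875_dirty::{op_name}"
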
diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py index f5ce0d6f5a890a58b89e6a501216a9822341323f..6d8654bdc718cd893f7899e9cdb6cd20544d189c 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_scatter kernel. diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/cumsum.py index 0d7c5137c900d66a959a76fa681436362a4df906..e2b7572391e20045d335cf7337246e8a9b9f57ef 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/cumsum.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/cumsum.py @@ -11,7 +11,7 @@ import torch # instructions for building the c++ operations. try: # import megablocks_ops as ops # type: ignore - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/gather.py index 41b09a1233e8a996ff1062579cd0810d095ad1e6..4edf4541dac52abab94151efb414acaa7711f8f6 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for gather kernel. diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram.py index 77e06a29163aa335dc61e8c50932d5f83322ae1d..7b3f058ec373cbba7555704fb5e4212c3cc75d9d 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py index 9de8e65281c829ba63b5c60853dfa5bb0a333988..c57b7bf8228e01237236748147368b09ffdf8072 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. 
import ops _HISTOGRAM_TESTS = ( (16384, torch.int32, 2), diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py index bfa7b7c2f697f9454c5381e31ee78040be5b229e..9b2c7f062071758def741bf3af2392d3a4855f89 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -7,7 +7,7 @@ import stk import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops # Calling tensor.t() calls tensor.transpose(0, 1) which calls diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_gather.py index f272a7768dc6468e81bd2cd25294ca16a6826c08..0ffe1369d6adbce5c4a54155c636d1a4c022a41d 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_gather kernel. diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter.py index 9ff81dd9456a04258346266e9e85871bda56c65b..6685b0b83e7b1d92e7a772715991ead8ad94b153 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_scatter kernel. diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py index 81dde4e4b4c0f550c6027390e8de285fa4d842f7..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops _PADDED_SCATTER_BENCHMARK = ( # dMoE-Medium, 8-way EMP. diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py index 837f07e2858d5c65c14e3f262617e7d2985ce901..6536eeeae402659a087e5c51ef9840627af56501 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. 
import benchmark_util, ops _PERMUTE_TESTS = ( (16384, 768, 2), diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/replicate.py index e570e22093339d1b71dc830edf656897c67f40e7..26daf0eede330603a4b8ea7167faf1411d07ca93 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/replicate.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/replicate.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/scatter.py index a5aaafc402a651a47595c25dbf71e382903e5022..3f4bacf3422543e11abb795c279876147f8610a8 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/scatter.py @@ -6,7 +6,7 @@ from typing import Any, Optional import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for scatter kernel. diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort.py index a64a4553323d8de71a1d5f1821bac50249c68197..bda3bf64283e39533c2eae3627e76bb2d0262c9f 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py index f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec..a92ff957d4c552c6e61d9279a7989795472af7b7 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. import ops _SORT_TESTS = ( (16384, torch.int32, None), diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/topology.py index b4a693a82a36c052e51afb4a02077e7849a4abf1..76a50d3164db20534b099dcb4d8487a7aef25d15 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/topology.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/topology.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. 
try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..ac1dda0ba2b7480d7680927a42a5de7be43762d2 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:746c9b13bc0a8cbf0524a573252beb4a7490dda851e0246aa4e38ac3828c62e7 +size 11931080 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so deleted file mode 100755 index dbff38a1256f476e63a04bf2d924acda689595c4..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8507dd1e6fc8f4df45af233d506ef96b962cacecf9e2d0694247547b0dd7dde0 -size 11931080 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py index fb18e304d3a97b1801f4f28ccd9b74403f400f6f..c9a9faa3cee07c8b7016a71704ea26e47c460d57 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_a585153_dirty -ops = torch.ops._megablocks_a585153_dirty +from . import _megablocks_6756875_dirty +ops = torch.ops._megablocks_6756875_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py index 7c692f95852950a7b97aaaa2aa3d15e4297b6192..76037d8039cbfc2f0577275c78e4bc0be762592a 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -10,7 +10,8 @@ import torch # We import the backend operations from the megablocks package as # grouped_gemm is vendored in megablocks in this repository. # from ... 
import _ops as backend -from megablocks._ops import ops as backend # type: ignore +# from megablocks._ops import ops as backend # type: ignore +from .._ops import ops as backend # type: ignore def _allocate_output(a, b, batch_sizes, trans_a, trans_b): assert not (trans_a and trans_b) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py index 3962c771c90012535aab058443f89c541a2e9236..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py @@ -9,7 +9,8 @@ import torch import torch.distributed as dist import torch.nn.functional as F -import megablocks.grouped_gemm_util as grouped_gemm +# import megablocks.grouped_gemm_util as grouped_gemm +from .. import grouped_gemm_util as grouped_gemm # Type annotation for in-place Tensor initialization function. InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/common.py index ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927..2d07109702963ba48a3b94ab860807954dfd79c1 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/common.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/common.py @@ -3,7 +3,7 @@ import torch -from megablocks.layers.arguments import Arguments +from .arguments import Arguments def dtype(args: Arguments): diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py index d765bd04387a29bddd2789fae04821635b555e82..de2ed047042e438c7190ebb139b6f7f30009734c 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -3,8 +3,8 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +from . import glu, mlp +from .arguments import Arguments MlpType = Union[mlp.SparseMLP, glu.SparseGLU] diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py index 205727ff4d63f9e8dc9648acaac99a97f3394d6f..561fd7982254513e9a9b4f0514edd326b0561256 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py @@ -6,11 +6,14 @@ import stk.ops import torch from stk import Matrix -import megablocks.ops as ops -# from megablocks.ops import ops -from megablocks.layers import common, dmlp_registry, moe, mpu -from megablocks.layers.arguments import Arguments - +# import megablocks.ops as ops +# # from megablocks.ops import ops +# from megablocks.layers import common, dmlp_registry, moe, mpu +# from megablocks.layers.arguments import Arguments + +from .. import ops +from . 
import common, dmlp_registry, moe, mpu +from .arguments import Arguments def promote_scalar(x): return x.view(1) if not len(x.size()) else x diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py index cbe0c915c307e7f7cade3ea3ff679399635fcd81..69ef9ead5461a5b91ace107652fd8d1c3d300b1e 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py @@ -4,11 +4,22 @@ import stk.ops import torch -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import Arguments -from megablocks.layers.mlp import ( +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import Arguments +# from megablocks.layers.mlp import ( +# SharedMLP, +# SparseMLP, +# create_dmoe_expert_weights, +# resolve_dtensor, +# ) + +from .. import grouped_gemm_util as gg +from . import common, mpu +from .activation_fn import act_fn +from .arguments import Arguments +from .mlp import ( SharedMLP, SparseMLP, create_dmoe_expert_weights, diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py index 4acbd94f212ce906ace9de78dd3ffc6afa03f97e..74d1166931b712635131985b25a89f4ca23e576d 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py @@ -6,7 +6,8 @@ import gc import torch import torch.distributed as dist -from megablocks.layers import arguments, dmoe +# from megablocks.layers import arguments, dmoe +from . import arguments, dmoe _TESTS = ((8, 2048, 4096, 4096, 32, 4),) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py index 6e6f4d82441e3a9c185db5bfdf686d53790dde26..0b6747b1eec2047eb9480592ca550b6c9e868f5e 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py @@ -9,11 +9,15 @@ import stk.ops import torch from packaging import version -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, gelu, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, gelu, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + +from .. import grouped_gemm_util as gg +from . 
import common, gelu, mpu +from .activation_fn import act_fn +from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn class ScaleGradient(torch.autograd.Function): diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py index 9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e..d0a4aeaacc9c86fc70944e730c53f7a55644e05e 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py @@ -6,10 +6,27 @@ import numpy as np import torch import torch.distributed as dist -import megablocks.ops as ops -from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -from megablocks.layers.all_to_all import all_to_all -from megablocks.layers.arguments import Arguments +# import megablocks.ops as ops +# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +# from megablocks.layers.all_to_all import all_to_all +# from megablocks.layers.arguments import Arguments + +from ..ops import ( + sort, + histogram, + inclusive_cumsum, + exclusive_cumsum, + binned_gather, + binned_scatter, + gather, + scatter, + repeat, + replicate, +) + +from . import common, mlp, mpu, router, sharedexpert_registry +from .arguments import Arguments +from .all_to_all import all_to_all _LOAD_BALANCING_LOSS = [] @@ -158,7 +175,8 @@ class ParallelMLP(torch.nn.Module): # prior? Could we place the `torch.max` operation to return # 32-bit expert indices? top_expert = top_expert.int() - output = ops.sort(top_expert, self.sort_end_bit) + # output = ops.sort(top_expert, self.sort_end_bit) + output = sort(top_expert, self.sort_end_bit) assert output is not None bin_ids, indices = output @@ -168,10 +186,12 @@ class ParallelMLP(torch.nn.Module): # TODO(tgale): Does the sorted data produce a more favorable # data distribution for histogram? Or is the op parallelism # worth more? - tokens_per_expert = ops.histogram(top_expert, self.num_experts) + # tokens_per_expert = ops.histogram(top_expert, self.num_experts) + tokens_per_expert = histogram(top_expert, self.num_experts) # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) + # bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = inclusive_cumsum(tokens_per_expert, 0) assert bins is not None bins = bins.view(1) if not len(bins.size()) else bins @@ -195,7 +215,8 @@ class ParallelMLP(torch.nn.Module): ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + output = binned_gather(x, indices, bins, expert_capacity, top_k) assert output is not None x = output @@ -204,7 +225,9 @@ class ParallelMLP(torch.nn.Module): x = self.mlp(x) # Un-route the data for the MoE output. - return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return binned_scatter(x, indices, expert_weights, bins, top_k) + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # x: [sl, bs, hs] @@ -264,7 +287,8 @@ class ParallelMLP(torch.nn.Module): # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. # Replicate the token counts so every device gets the counts. 
- repeated_tokens_per_expert = ops.repeat( + # repeated_tokens_per_expert = ops.repeat( + repeated_tokens_per_expert = repeat( tokens_per_expert, (mpu.hidden_sharding_degree(self.args),), ) @@ -285,7 +309,8 @@ class ParallelMLP(torch.nn.Module): # This view updates the shape of the tensor from [sl, bs, hs] to # [sl * bs, hs] prior to the permutation. x = x.view(-1, x.shape[-1]) - output = ops.gather(x, indices, bin_ids, bins, self.top_k) + # output = ops.gather(x, indices, bin_ids, bins, self.top_k) + output = gather(x, indices, bin_ids, bins, self.top_k) assert output is not None x = output @@ -317,7 +342,8 @@ class ParallelMLP(torch.nn.Module): # get all of the tokens assigned to them. # # TODO(tgale): Fuse this into the prior, local permutation. - x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) # Start the cross-device permutation asynchronously so we can # overlap communication with computation. @@ -336,7 +362,8 @@ class ParallelMLP(torch.nn.Module): # for expert computation we'll do one more local permutation. The # rest of this torch.no_grad() scope sets up the indices and bins # for this permutation. - replicate_bins = ops.inclusive_cumsum( + # replicate_bins = ops.inclusive_cumsum( + replicate_bins = inclusive_cumsum( parallel_tokens_per_expert.flatten(), 0, ) @@ -351,14 +378,16 @@ class ParallelMLP(torch.nn.Module): ), mpu.experts_per_rank(self.args), ) - parallel_top_expert = ops.replicate( + # parallel_top_expert = ops.replicate( + parallel_top_expert = replicate( parallel_top_expert.unsqueeze(dim=0), replicate_bins, tokens_received, ).flatten() # TODO(tgale): The sort_end_bit here can be reduced. - parallel_bin_ids, parallel_indices = ops.sort( + # parallel_bin_ids, parallel_indices = ops.sort( + parallel_bin_ids, parallel_indices = sort( parallel_top_expert, self.sort_end_bit, ) @@ -368,7 +397,8 @@ class ParallelMLP(torch.nn.Module): dim=0, dtype=torch.int, ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) # If expert_capacity is set to zero, set the number of tokens @@ -416,10 +446,12 @@ class ParallelMLP(torch.nn.Module): -1, self.args.hidden_size, ) - x = ops.sum(x.view(shape), dim=0) + # x = ops.sum(x.view(shape), dim=0) + x = x.view(shape).sum(dim=0) # Un-permute locally to setup for the next series of operations. 
- x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py index b23213902f4567d7fdb0158cbcf5406c2b2aa601..434e143ab42bf3f83406d69e9dd1f72777716e22 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py @@ -6,7 +6,8 @@ from typing import Optional import torch import torch.distributed as dist -from megablocks.layers.arguments import Arguments +# from megablocks.layers.arguments import Arguments +from .arguments import Arguments class MoeParam(torch.Tensor): diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/router.py index 2c9dcd9e433322482a80d5b95afccee5c12368f8..37cb2782348d62583376f1a183c7ede83601216d 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/router.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/router.py @@ -4,8 +4,10 @@ from typing import Any import torch -from megablocks.layers import common -from megablocks.layers.arguments import Arguments +# from megablocks.layers import common +# from megablocks.layers.arguments import Arguments +from . import common +from .arguments import Arguments _ROUTER_LOGITS = [] diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py index 0f62db39a2dda5642f6baa9d9b949f7c31cf6d35..5840862f88f370ace5fd49bd0612fc98d186cc49 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -3,8 +3,10 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +# from megablocks.layers import glu, mlp +# from megablocks.layers.arguments import Arguments +from . 
import glu, mlp +from .arguments import Arguments _REGISTRY = { 'mlp': mlp.SharedMLP, diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py index b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c..b944080df810d0b0cfc571f3009b0098a651f9b7 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py @@ -1,20 +1,20 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from megablocks.ops.binned_gather import binned_gather -from megablocks.ops.binned_scatter import binned_scatter -from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum -from megablocks.ops.gather import gather -from megablocks.ops.histogram import histogram -from megablocks.ops.padded_gather import padded_gather -from megablocks.ops.padded_scatter import padded_scatter -from megablocks.ops.repeat import repeat -from megablocks.ops.replicate import replicate -from megablocks.ops.round_up import round_up -from megablocks.ops.scatter import scatter -from megablocks.ops.sort import sort -from megablocks.ops.sum import sum -from megablocks.ops.topology import topology +from .binned_gather import binned_gather +from .binned_scatter import binned_scatter +from .cumsum import exclusive_cumsum, inclusive_cumsum +from .gather import gather +from .histogram import histogram +from .padded_gather import padded_gather +from .padded_scatter import padded_scatter +from .repeat import repeat +from .replicate import replicate +from .round_up import round_up +from .scatter import scatter +from .sort import sort +from .sum import sum +from .topology import topology __all__ = [ 'binned_gather', diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 47b95301dbb35b0c3154d1ff2878d20792ce5cb2..43d267dbe2570e4a1f59fd561398668cc2bc0920 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -4,8 +4,11 @@ import torch import torch.distributed as dist -from megablocks import benchmark_util -from megablocks.layers.all_to_all import all_to_all +# from megablocks import benchmark_util +# from megablocks.layers.all_to_all import all_to_all + +from .. import benchmark_util +from ..layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py index 89cce1b627137d7f987baca62fd6d82c5c04659a..8ae2ad8388b06db46a13f8fa46083619c44eefe2 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_gather kernel. 
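[Reviewer note, not part of the patch] The moe.py hunks in this patch replace ops.sum(x.view(shape), dim=0) with plain x.view(shape).sum(dim=0), dropping the custom sum op in favor of the built-in tensor reduction. A minimal sanity sketch of that substitution is below; the sizes are made-up illustrative values and only torch itself is assumed:

import torch

# Hedged sketch: the real code sums over the replicated hidden-sharding dimension
# before un-permuting tokens; here we just check the shape algebra of the patched line.
hidden_sharding_degree, tokens, hidden_size = 2, 4, 8  # illustrative values only
x = torch.randn(hidden_sharding_degree * tokens, hidden_size)
shape = (hidden_sharding_degree, -1, hidden_size)
summed = x.view(shape).sum(dim=0)  # what the patched line computes
assert summed.shape == (tokens, hidden_size)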
diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py index f5ce0d6f5a890a58b89e6a501216a9822341323f..6d8654bdc718cd893f7899e9cdb6cd20544d189c 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_scatter kernel. diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py index 0d7c5137c900d66a959a76fa681436362a4df906..e2b7572391e20045d335cf7337246e8a9b9f57ef 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py @@ -11,7 +11,7 @@ import torch # instructions for building the c++ operations. try: # import megablocks_ops as ops # type: ignore - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py index 41b09a1233e8a996ff1062579cd0810d095ad1e6..4edf4541dac52abab94151efb414acaa7711f8f6 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for gather kernel. diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py index 77e06a29163aa335dc61e8c50932d5f83322ae1d..7b3f058ec373cbba7555704fb5e4212c3cc75d9d 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py index 9de8e65281c829ba63b5c60853dfa5bb0a333988..c57b7bf8228e01237236748147368b09ffdf8072 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. 
import ops _HISTOGRAM_TESTS = ( (16384, torch.int32, 2), diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py index bfa7b7c2f697f9454c5381e31ee78040be5b229e..9b2c7f062071758def741bf3af2392d3a4855f89 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -7,7 +7,7 @@ import stk import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops # Calling tensor.t() calls tensor.transpose(0, 1) which calls diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py index f272a7768dc6468e81bd2cd25294ca16a6826c08..0ffe1369d6adbce5c4a54155c636d1a4c022a41d 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_gather kernel. diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py index 9ff81dd9456a04258346266e9e85871bda56c65b..6685b0b83e7b1d92e7a772715991ead8ad94b153 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_scatter kernel. diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py index 81dde4e4b4c0f550c6027390e8de285fa4d842f7..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops _PADDED_SCATTER_BENCHMARK = ( # dMoE-Medium, 8-way EMP. diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py index 837f07e2858d5c65c14e3f262617e7d2985ce901..6536eeeae402659a087e5c51ef9840627af56501 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. 
import benchmark_util, ops _PERMUTE_TESTS = ( (16384, 768, 2), diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py index e570e22093339d1b71dc830edf656897c67f40e7..26daf0eede330603a4b8ea7167faf1411d07ca93 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py index a5aaafc402a651a47595c25dbf71e382903e5022..3f4bacf3422543e11abb795c279876147f8610a8 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py @@ -6,7 +6,7 @@ from typing import Any, Optional import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for scatter kernel. diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py index a64a4553323d8de71a1d5f1821bac50249c68197..bda3bf64283e39533c2eae3627e76bb2d0262c9f 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py index f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec..a92ff957d4c552c6e61d9279a7989795472af7b7 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. import ops _SORT_TESTS = ( (16384, torch.int32, None), diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py index b4a693a82a36c052e51afb4a02077e7849a4abf1..76a50d3164db20534b099dcb4d8487a7aef25d15 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. 
try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..a8f46771a14d3e54f5cbe5d1ae726dd1888a70c1 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2fbfef8842d7110f0e496806edee088172cf4fc2e63e1782e8b6a3008bd3d304 +size 10510072 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so deleted file mode 100755 index d7813fa66f923338b76230f282683661a5ff4022..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6dc0dcea20fc1350689addf7cb9927f7bb709f68ed89d4c711b0f7db579a463b -size 10510072 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py index fb18e304d3a97b1801f4f28ccd9b74403f400f6f..c9a9faa3cee07c8b7016a71704ea26e47c460d57 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_a585153_dirty -ops = torch.ops._megablocks_a585153_dirty +from . import _megablocks_6756875_dirty +ops = torch.ops._megablocks_6756875_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py index 7c692f95852950a7b97aaaa2aa3d15e4297b6192..76037d8039cbfc2f0577275c78e4bc0be762592a 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -10,7 +10,8 @@ import torch # We import the backend operations from the megablocks package as # grouped_gemm is vendored in megablocks in this repository. # from ... 
import _ops as backend -from megablocks._ops import ops as backend # type: ignore +# from megablocks._ops import ops as backend # type: ignore +from .._ops import ops as backend # type: ignore def _allocate_output(a, b, batch_sizes, trans_a, trans_b): assert not (trans_a and trans_b) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/arguments.py index 3962c771c90012535aab058443f89c541a2e9236..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/arguments.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/arguments.py @@ -9,7 +9,8 @@ import torch import torch.distributed as dist import torch.nn.functional as F -import megablocks.grouped_gemm_util as grouped_gemm +# import megablocks.grouped_gemm_util as grouped_gemm +from .. import grouped_gemm_util as grouped_gemm # Type annotation for in-place Tensor initialization function. InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/common.py index ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927..2d07109702963ba48a3b94ab860807954dfd79c1 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/common.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/common.py @@ -3,7 +3,7 @@ import torch -from megablocks.layers.arguments import Arguments +from .arguments import Arguments def dtype(args: Arguments): diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py index d765bd04387a29bddd2789fae04821635b555e82..de2ed047042e438c7190ebb139b6f7f30009734c 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -3,8 +3,8 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +from . import glu, mlp +from .arguments import Arguments MlpType = Union[mlp.SparseMLP, glu.SparseGLU] diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmoe.py index 205727ff4d63f9e8dc9648acaac99a97f3394d6f..561fd7982254513e9a9b4f0514edd326b0561256 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmoe.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmoe.py @@ -6,11 +6,14 @@ import stk.ops import torch from stk import Matrix -import megablocks.ops as ops -# from megablocks.ops import ops -from megablocks.layers import common, dmlp_registry, moe, mpu -from megablocks.layers.arguments import Arguments - +# import megablocks.ops as ops +# # from megablocks.ops import ops +# from megablocks.layers import common, dmlp_registry, moe, mpu +# from megablocks.layers.arguments import Arguments + +from .. import ops +from . 
import common, dmlp_registry, moe, mpu +from .arguments import Arguments def promote_scalar(x): return x.view(1) if not len(x.size()) else x diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/glu.py index cbe0c915c307e7f7cade3ea3ff679399635fcd81..69ef9ead5461a5b91ace107652fd8d1c3d300b1e 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/glu.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/glu.py @@ -4,11 +4,22 @@ import stk.ops import torch -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import Arguments -from megablocks.layers.mlp import ( +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import Arguments +# from megablocks.layers.mlp import ( +# SharedMLP, +# SparseMLP, +# create_dmoe_expert_weights, +# resolve_dtensor, +# ) + +from .. import grouped_gemm_util as gg +from . import common, mpu +from .activation_fn import act_fn +from .arguments import Arguments +from .mlp import ( SharedMLP, SparseMLP, create_dmoe_expert_weights, diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/memory_test.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/memory_test.py index 4acbd94f212ce906ace9de78dd3ffc6afa03f97e..74d1166931b712635131985b25a89f4ca23e576d 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/memory_test.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/memory_test.py @@ -6,7 +6,8 @@ import gc import torch import torch.distributed as dist -from megablocks.layers import arguments, dmoe +# from megablocks.layers import arguments, dmoe +from . import arguments, dmoe _TESTS = ((8, 2048, 4096, 4096, 32, 4),) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mlp.py index 6e6f4d82441e3a9c185db5bfdf686d53790dde26..0b6747b1eec2047eb9480592ca550b6c9e868f5e 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mlp.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mlp.py @@ -9,11 +9,15 @@ import stk.ops import torch from packaging import version -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, gelu, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, gelu, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + +from .. import grouped_gemm_util as gg +from . 
import common, gelu, mpu +from .activation_fn import act_fn +from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn class ScaleGradient(torch.autograd.Function): diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/moe.py index 9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e..d0a4aeaacc9c86fc70944e730c53f7a55644e05e 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/moe.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/moe.py @@ -6,10 +6,27 @@ import numpy as np import torch import torch.distributed as dist -import megablocks.ops as ops -from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -from megablocks.layers.all_to_all import all_to_all -from megablocks.layers.arguments import Arguments +# import megablocks.ops as ops +# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +# from megablocks.layers.all_to_all import all_to_all +# from megablocks.layers.arguments import Arguments + +from ..ops import ( + sort, + histogram, + inclusive_cumsum, + exclusive_cumsum, + binned_gather, + binned_scatter, + gather, + scatter, + repeat, + replicate, +) + +from . import common, mlp, mpu, router, sharedexpert_registry +from .arguments import Arguments +from .all_to_all import all_to_all _LOAD_BALANCING_LOSS = [] @@ -158,7 +175,8 @@ class ParallelMLP(torch.nn.Module): # prior? Could we place the `torch.max` operation to return # 32-bit expert indices? top_expert = top_expert.int() - output = ops.sort(top_expert, self.sort_end_bit) + # output = ops.sort(top_expert, self.sort_end_bit) + output = sort(top_expert, self.sort_end_bit) assert output is not None bin_ids, indices = output @@ -168,10 +186,12 @@ class ParallelMLP(torch.nn.Module): # TODO(tgale): Does the sorted data produce a more favorable # data distribution for histogram? Or is the op parallelism # worth more? - tokens_per_expert = ops.histogram(top_expert, self.num_experts) + # tokens_per_expert = ops.histogram(top_expert, self.num_experts) + tokens_per_expert = histogram(top_expert, self.num_experts) # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) + # bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = inclusive_cumsum(tokens_per_expert, 0) assert bins is not None bins = bins.view(1) if not len(bins.size()) else bins @@ -195,7 +215,8 @@ class ParallelMLP(torch.nn.Module): ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + output = binned_gather(x, indices, bins, expert_capacity, top_k) assert output is not None x = output @@ -204,7 +225,9 @@ class ParallelMLP(torch.nn.Module): x = self.mlp(x) # Un-route the data for the MoE output. - return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return binned_scatter(x, indices, expert_weights, bins, top_k) + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # x: [sl, bs, hs] @@ -264,7 +287,8 @@ class ParallelMLP(torch.nn.Module): # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. # Replicate the token counts so every device gets the counts. 
- repeated_tokens_per_expert = ops.repeat( + # repeated_tokens_per_expert = ops.repeat( + repeated_tokens_per_expert = repeat( tokens_per_expert, (mpu.hidden_sharding_degree(self.args),), ) @@ -285,7 +309,8 @@ class ParallelMLP(torch.nn.Module): # This view updates the shape of the tensor from [sl, bs, hs] to # [sl * bs, hs] prior to the permutation. x = x.view(-1, x.shape[-1]) - output = ops.gather(x, indices, bin_ids, bins, self.top_k) + # output = ops.gather(x, indices, bin_ids, bins, self.top_k) + output = gather(x, indices, bin_ids, bins, self.top_k) assert output is not None x = output @@ -317,7 +342,8 @@ class ParallelMLP(torch.nn.Module): # get all of the tokens assigned to them. # # TODO(tgale): Fuse this into the prior, local permutation. - x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) # Start the cross-device permutation asynchronously so we can # overlap communication with computation. @@ -336,7 +362,8 @@ class ParallelMLP(torch.nn.Module): # for expert computation we'll do one more local permutation. The # rest of this torch.no_grad() scope sets up the indices and bins # for this permutation. - replicate_bins = ops.inclusive_cumsum( + # replicate_bins = ops.inclusive_cumsum( + replicate_bins = inclusive_cumsum( parallel_tokens_per_expert.flatten(), 0, ) @@ -351,14 +378,16 @@ class ParallelMLP(torch.nn.Module): ), mpu.experts_per_rank(self.args), ) - parallel_top_expert = ops.replicate( + # parallel_top_expert = ops.replicate( + parallel_top_expert = replicate( parallel_top_expert.unsqueeze(dim=0), replicate_bins, tokens_received, ).flatten() # TODO(tgale): The sort_end_bit here can be reduced. - parallel_bin_ids, parallel_indices = ops.sort( + # parallel_bin_ids, parallel_indices = ops.sort( + parallel_bin_ids, parallel_indices = sort( parallel_top_expert, self.sort_end_bit, ) @@ -368,7 +397,8 @@ class ParallelMLP(torch.nn.Module): dim=0, dtype=torch.int, ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) # If expert_capacity is set to zero, set the number of tokens @@ -416,10 +446,12 @@ class ParallelMLP(torch.nn.Module): -1, self.args.hidden_size, ) - x = ops.sum(x.view(shape), dim=0) + # x = ops.sum(x.view(shape), dim=0) + x = x.view(shape).sum(dim=0) # Un-permute locally to setup for the next series of operations. 
- x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mpu.py index b23213902f4567d7fdb0158cbcf5406c2b2aa601..434e143ab42bf3f83406d69e9dd1f72777716e22 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mpu.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mpu.py @@ -6,7 +6,8 @@ from typing import Optional import torch import torch.distributed as dist -from megablocks.layers.arguments import Arguments +# from megablocks.layers.arguments import Arguments +from .arguments import Arguments class MoeParam(torch.Tensor): diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/router.py index 2c9dcd9e433322482a80d5b95afccee5c12368f8..37cb2782348d62583376f1a183c7ede83601216d 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/router.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/router.py @@ -4,8 +4,10 @@ from typing import Any import torch -from megablocks.layers import common -from megablocks.layers.arguments import Arguments +# from megablocks.layers import common +# from megablocks.layers.arguments import Arguments +from . import common +from .arguments import Arguments _ROUTER_LOGITS = [] diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py index 0f62db39a2dda5642f6baa9d9b949f7c31cf6d35..5840862f88f370ace5fd49bd0612fc98d186cc49 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -3,8 +3,10 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +# from megablocks.layers import glu, mlp +# from megablocks.layers.arguments import Arguments +from . 
import glu, mlp +from .arguments import Arguments _REGISTRY = { 'mlp': mlp.SharedMLP, diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/__init__.py index b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c..b944080df810d0b0cfc571f3009b0098a651f9b7 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/__init__.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/__init__.py @@ -1,20 +1,20 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from megablocks.ops.binned_gather import binned_gather -from megablocks.ops.binned_scatter import binned_scatter -from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum -from megablocks.ops.gather import gather -from megablocks.ops.histogram import histogram -from megablocks.ops.padded_gather import padded_gather -from megablocks.ops.padded_scatter import padded_scatter -from megablocks.ops.repeat import repeat -from megablocks.ops.replicate import replicate -from megablocks.ops.round_up import round_up -from megablocks.ops.scatter import scatter -from megablocks.ops.sort import sort -from megablocks.ops.sum import sum -from megablocks.ops.topology import topology +from .binned_gather import binned_gather +from .binned_scatter import binned_scatter +from .cumsum import exclusive_cumsum, inclusive_cumsum +from .gather import gather +from .histogram import histogram +from .padded_gather import padded_gather +from .padded_scatter import padded_scatter +from .repeat import repeat +from .replicate import replicate +from .round_up import round_up +from .scatter import scatter +from .sort import sort +from .sum import sum +from .topology import topology __all__ = [ 'binned_gather', diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 47b95301dbb35b0c3154d1ff2878d20792ce5cb2..43d267dbe2570e4a1f59fd561398668cc2bc0920 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -4,8 +4,11 @@ import torch import torch.distributed as dist -from megablocks import benchmark_util -from megablocks.layers.all_to_all import all_to_all +# from megablocks import benchmark_util +# from megablocks.layers.all_to_all import all_to_all + +from .. import benchmark_util +from ..layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_gather.py index 89cce1b627137d7f987baca62fd6d82c5c04659a..8ae2ad8388b06db46a13f8fa46083619c44eefe2 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_gather kernel. 
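[Reviewer note, not part of the patch] The _ops.py hunks rename the op namespace from _megablocks_a585153_dirty to _megablocks_6756875_dirty, so the `from . import _megablocks_...` binding and the f-string prefix in add_op_namespace_prefix must change together. A small self-contained sketch of how that namespace string is composed ("some_op" is a hypothetical op name used only for illustration):

# Hedged sketch of the pattern shown in the _ops.py hunks: the compiled extension is
# exposed under a torch.ops namespace matching the shared-object name, and helper code
# builds fully-qualified "namespace::op" strings from that name.
NAMESPACE = "_megablocks_6756875_dirty"  # value taken verbatim from the hunks above

def add_op_namespace_prefix(op_name: str) -> str:
    # Mirrors _ops.py: prefix an op name with the extension's namespace.
    return f"{NAMESPACE}::{op_name}"

assert add_op_namespace_prefix("some_op") == "_megablocks_6756875_dirty::some_op"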
diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_scatter.py index f5ce0d6f5a890a58b89e6a501216a9822341323f..6d8654bdc718cd893f7899e9cdb6cd20544d189c 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_scatter kernel. diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/cumsum.py index 0d7c5137c900d66a959a76fa681436362a4df906..e2b7572391e20045d335cf7337246e8a9b9f57ef 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/cumsum.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/cumsum.py @@ -11,7 +11,7 @@ import torch # instructions for building the c++ operations. try: # import megablocks_ops as ops # type: ignore - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/gather.py index 41b09a1233e8a996ff1062579cd0810d095ad1e6..4edf4541dac52abab94151efb414acaa7711f8f6 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for gather kernel. diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram.py index 77e06a29163aa335dc61e8c50932d5f83322ae1d..7b3f058ec373cbba7555704fb5e4212c3cc75d9d 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py index 9de8e65281c829ba63b5c60853dfa5bb0a333988..c57b7bf8228e01237236748147368b09ffdf8072 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. 
import ops _HISTOGRAM_TESTS = ( (16384, torch.int32, 2), diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py index bfa7b7c2f697f9454c5381e31ee78040be5b229e..9b2c7f062071758def741bf3af2392d3a4855f89 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -7,7 +7,7 @@ import stk import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops # Calling tensor.t() calls tensor.transpose(0, 1) which calls diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_gather.py index f272a7768dc6468e81bd2cd25294ca16a6826c08..0ffe1369d6adbce5c4a54155c636d1a4c022a41d 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_gather kernel. diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter.py index 9ff81dd9456a04258346266e9e85871bda56c65b..6685b0b83e7b1d92e7a772715991ead8ad94b153 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_scatter kernel. diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py index 81dde4e4b4c0f550c6027390e8de285fa4d842f7..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops _PADDED_SCATTER_BENCHMARK = ( # dMoE-Medium, 8-way EMP. diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py index 837f07e2858d5c65c14e3f262617e7d2985ce901..6536eeeae402659a087e5c51ef9840627af56501 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. 
import benchmark_util, ops _PERMUTE_TESTS = ( (16384, 768, 2), diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/replicate.py index e570e22093339d1b71dc830edf656897c67f40e7..26daf0eede330603a4b8ea7167faf1411d07ca93 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/replicate.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/replicate.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/scatter.py index a5aaafc402a651a47595c25dbf71e382903e5022..3f4bacf3422543e11abb795c279876147f8610a8 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/scatter.py @@ -6,7 +6,7 @@ from typing import Any, Optional import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for scatter kernel. diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort.py index a64a4553323d8de71a1d5f1821bac50249c68197..bda3bf64283e39533c2eae3627e76bb2d0262c9f 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py index f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec..a92ff957d4c552c6e61d9279a7989795472af7b7 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. import ops _SORT_TESTS = ( (16384, torch.int32, None), diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/topology.py index b4a693a82a36c052e51afb4a02077e7849a4abf1..76a50d3164db20534b099dcb4d8487a7aef25d15 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/topology.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/topology.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. 
try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..6c37e4fc559f163d315b5e00282d5a436b9defe6 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:633cb26576e8af1c8cc4a0b9e86287fb1f5c7651452d400e5ea6e9e95c49f1bd +size 11857952 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so deleted file mode 100755 index e298e56ddd77607fd0010213f29fbfd2ceb6e2b0..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:eb0c2b91105c2f32f590aaa9d90ae2d6b36834bae9b35fb55c4b4fc90da56bc3 -size 11857952 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py index fb18e304d3a97b1801f4f28ccd9b74403f400f6f..c9a9faa3cee07c8b7016a71704ea26e47c460d57 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_a585153_dirty -ops = torch.ops._megablocks_a585153_dirty +from . import _megablocks_6756875_dirty +ops = torch.ops._megablocks_6756875_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py index 7c692f95852950a7b97aaaa2aa3d15e4297b6192..76037d8039cbfc2f0577275c78e4bc0be762592a 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -10,7 +10,8 @@ import torch # We import the backend operations from the megablocks package as # grouped_gemm is vendored in megablocks in this repository. # from ... 
import _ops as backend -from megablocks._ops import ops as backend # type: ignore +# from megablocks._ops import ops as backend # type: ignore +from .._ops import ops as backend # type: ignore def _allocate_output(a, b, batch_sizes, trans_a, trans_b): assert not (trans_a and trans_b) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/arguments.py index 3962c771c90012535aab058443f89c541a2e9236..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/arguments.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/arguments.py @@ -9,7 +9,8 @@ import torch import torch.distributed as dist import torch.nn.functional as F -import megablocks.grouped_gemm_util as grouped_gemm +# import megablocks.grouped_gemm_util as grouped_gemm +from .. import grouped_gemm_util as grouped_gemm # Type annotation for in-place Tensor initialization function. InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/common.py index ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927..2d07109702963ba48a3b94ab860807954dfd79c1 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/common.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/common.py @@ -3,7 +3,7 @@ import torch -from megablocks.layers.arguments import Arguments +from .arguments import Arguments def dtype(args: Arguments): diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py index d765bd04387a29bddd2789fae04821635b555e82..de2ed047042e438c7190ebb139b6f7f30009734c 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -3,8 +3,8 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +from . import glu, mlp +from .arguments import Arguments MlpType = Union[mlp.SparseMLP, glu.SparseGLU] diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmoe.py index 205727ff4d63f9e8dc9648acaac99a97f3394d6f..561fd7982254513e9a9b4f0514edd326b0561256 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmoe.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmoe.py @@ -6,11 +6,14 @@ import stk.ops import torch from stk import Matrix -import megablocks.ops as ops -# from megablocks.ops import ops -from megablocks.layers import common, dmlp_registry, moe, mpu -from megablocks.layers.arguments import Arguments - +# import megablocks.ops as ops +# # from megablocks.ops import ops +# from megablocks.layers import common, dmlp_registry, moe, mpu +# from megablocks.layers.arguments import Arguments + +from .. import ops +from . 
import common, dmlp_registry, moe, mpu +from .arguments import Arguments def promote_scalar(x): return x.view(1) if not len(x.size()) else x diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/glu.py index cbe0c915c307e7f7cade3ea3ff679399635fcd81..69ef9ead5461a5b91ace107652fd8d1c3d300b1e 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/glu.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/glu.py @@ -4,11 +4,22 @@ import stk.ops import torch -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import Arguments -from megablocks.layers.mlp import ( +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import Arguments +# from megablocks.layers.mlp import ( +# SharedMLP, +# SparseMLP, +# create_dmoe_expert_weights, +# resolve_dtensor, +# ) + +from .. import grouped_gemm_util as gg +from . import common, mpu +from .activation_fn import act_fn +from .arguments import Arguments +from .mlp import ( SharedMLP, SparseMLP, create_dmoe_expert_weights, diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/memory_test.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/memory_test.py index 4acbd94f212ce906ace9de78dd3ffc6afa03f97e..74d1166931b712635131985b25a89f4ca23e576d 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/memory_test.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/memory_test.py @@ -6,7 +6,8 @@ import gc import torch import torch.distributed as dist -from megablocks.layers import arguments, dmoe +# from megablocks.layers import arguments, dmoe +from . import arguments, dmoe _TESTS = ((8, 2048, 4096, 4096, 32, 4),) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mlp.py index 6e6f4d82441e3a9c185db5bfdf686d53790dde26..0b6747b1eec2047eb9480592ca550b6c9e868f5e 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mlp.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mlp.py @@ -9,11 +9,15 @@ import stk.ops import torch from packaging import version -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, gelu, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, gelu, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + +from .. import grouped_gemm_util as gg +from . 
import common, gelu, mpu +from .activation_fn import act_fn +from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn class ScaleGradient(torch.autograd.Function): diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/moe.py index 9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e..d0a4aeaacc9c86fc70944e730c53f7a55644e05e 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/moe.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/moe.py @@ -6,10 +6,27 @@ import numpy as np import torch import torch.distributed as dist -import megablocks.ops as ops -from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -from megablocks.layers.all_to_all import all_to_all -from megablocks.layers.arguments import Arguments +# import megablocks.ops as ops +# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +# from megablocks.layers.all_to_all import all_to_all +# from megablocks.layers.arguments import Arguments + +from ..ops import ( + sort, + histogram, + inclusive_cumsum, + exclusive_cumsum, + binned_gather, + binned_scatter, + gather, + scatter, + repeat, + replicate, +) + +from . import common, mlp, mpu, router, sharedexpert_registry +from .arguments import Arguments +from .all_to_all import all_to_all _LOAD_BALANCING_LOSS = [] @@ -158,7 +175,8 @@ class ParallelMLP(torch.nn.Module): # prior? Could we place the `torch.max` operation to return # 32-bit expert indices? top_expert = top_expert.int() - output = ops.sort(top_expert, self.sort_end_bit) + # output = ops.sort(top_expert, self.sort_end_bit) + output = sort(top_expert, self.sort_end_bit) assert output is not None bin_ids, indices = output @@ -168,10 +186,12 @@ class ParallelMLP(torch.nn.Module): # TODO(tgale): Does the sorted data produce a more favorable # data distribution for histogram? Or is the op parallelism # worth more? - tokens_per_expert = ops.histogram(top_expert, self.num_experts) + # tokens_per_expert = ops.histogram(top_expert, self.num_experts) + tokens_per_expert = histogram(top_expert, self.num_experts) # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) + # bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = inclusive_cumsum(tokens_per_expert, 0) assert bins is not None bins = bins.view(1) if not len(bins.size()) else bins @@ -195,7 +215,8 @@ class ParallelMLP(torch.nn.Module): ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + output = binned_gather(x, indices, bins, expert_capacity, top_k) assert output is not None x = output @@ -204,7 +225,9 @@ class ParallelMLP(torch.nn.Module): x = self.mlp(x) # Un-route the data for the MoE output. - return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return binned_scatter(x, indices, expert_weights, bins, top_k) + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # x: [sl, bs, hs] @@ -264,7 +287,8 @@ class ParallelMLP(torch.nn.Module): # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. # Replicate the token counts so every device gets the counts. 
- repeated_tokens_per_expert = ops.repeat( + # repeated_tokens_per_expert = ops.repeat( + repeated_tokens_per_expert = repeat( tokens_per_expert, (mpu.hidden_sharding_degree(self.args),), ) @@ -285,7 +309,8 @@ class ParallelMLP(torch.nn.Module): # This view updates the shape of the tensor from [sl, bs, hs] to # [sl * bs, hs] prior to the permutation. x = x.view(-1, x.shape[-1]) - output = ops.gather(x, indices, bin_ids, bins, self.top_k) + # output = ops.gather(x, indices, bin_ids, bins, self.top_k) + output = gather(x, indices, bin_ids, bins, self.top_k) assert output is not None x = output @@ -317,7 +342,8 @@ class ParallelMLP(torch.nn.Module): # get all of the tokens assigned to them. # # TODO(tgale): Fuse this into the prior, local permutation. - x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) # Start the cross-device permutation asynchronously so we can # overlap communication with computation. @@ -336,7 +362,8 @@ class ParallelMLP(torch.nn.Module): # for expert computation we'll do one more local permutation. The # rest of this torch.no_grad() scope sets up the indices and bins # for this permutation. - replicate_bins = ops.inclusive_cumsum( + # replicate_bins = ops.inclusive_cumsum( + replicate_bins = inclusive_cumsum( parallel_tokens_per_expert.flatten(), 0, ) @@ -351,14 +378,16 @@ class ParallelMLP(torch.nn.Module): ), mpu.experts_per_rank(self.args), ) - parallel_top_expert = ops.replicate( + # parallel_top_expert = ops.replicate( + parallel_top_expert = replicate( parallel_top_expert.unsqueeze(dim=0), replicate_bins, tokens_received, ).flatten() # TODO(tgale): The sort_end_bit here can be reduced. - parallel_bin_ids, parallel_indices = ops.sort( + # parallel_bin_ids, parallel_indices = ops.sort( + parallel_bin_ids, parallel_indices = sort( parallel_top_expert, self.sort_end_bit, ) @@ -368,7 +397,8 @@ class ParallelMLP(torch.nn.Module): dim=0, dtype=torch.int, ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) # If expert_capacity is set to zero, set the number of tokens @@ -416,10 +446,12 @@ class ParallelMLP(torch.nn.Module): -1, self.args.hidden_size, ) - x = ops.sum(x.view(shape), dim=0) + # x = ops.sum(x.view(shape), dim=0) + x = x.view(shape).sum(dim=0) # Un-permute locally to setup for the next series of operations. 
- x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mpu.py index b23213902f4567d7fdb0158cbcf5406c2b2aa601..434e143ab42bf3f83406d69e9dd1f72777716e22 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mpu.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mpu.py @@ -6,7 +6,8 @@ from typing import Optional import torch import torch.distributed as dist -from megablocks.layers.arguments import Arguments +# from megablocks.layers.arguments import Arguments +from .arguments import Arguments class MoeParam(torch.Tensor): diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/router.py index 2c9dcd9e433322482a80d5b95afccee5c12368f8..37cb2782348d62583376f1a183c7ede83601216d 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/router.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/router.py @@ -4,8 +4,10 @@ from typing import Any import torch -from megablocks.layers import common -from megablocks.layers.arguments import Arguments +# from megablocks.layers import common +# from megablocks.layers.arguments import Arguments +from . import common +from .arguments import Arguments _ROUTER_LOGITS = [] diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py index 0f62db39a2dda5642f6baa9d9b949f7c31cf6d35..5840862f88f370ace5fd49bd0612fc98d186cc49 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -3,8 +3,10 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +# from megablocks.layers import glu, mlp +# from megablocks.layers.arguments import Arguments +from . 
import glu, mlp +from .arguments import Arguments _REGISTRY = { 'mlp': mlp.SharedMLP, diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/__init__.py index b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c..b944080df810d0b0cfc571f3009b0098a651f9b7 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/__init__.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/__init__.py @@ -1,20 +1,20 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from megablocks.ops.binned_gather import binned_gather -from megablocks.ops.binned_scatter import binned_scatter -from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum -from megablocks.ops.gather import gather -from megablocks.ops.histogram import histogram -from megablocks.ops.padded_gather import padded_gather -from megablocks.ops.padded_scatter import padded_scatter -from megablocks.ops.repeat import repeat -from megablocks.ops.replicate import replicate -from megablocks.ops.round_up import round_up -from megablocks.ops.scatter import scatter -from megablocks.ops.sort import sort -from megablocks.ops.sum import sum -from megablocks.ops.topology import topology +from .binned_gather import binned_gather +from .binned_scatter import binned_scatter +from .cumsum import exclusive_cumsum, inclusive_cumsum +from .gather import gather +from .histogram import histogram +from .padded_gather import padded_gather +from .padded_scatter import padded_scatter +from .repeat import repeat +from .replicate import replicate +from .round_up import round_up +from .scatter import scatter +from .sort import sort +from .sum import sum +from .topology import topology __all__ = [ 'binned_gather', diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 47b95301dbb35b0c3154d1ff2878d20792ce5cb2..43d267dbe2570e4a1f59fd561398668cc2bc0920 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -4,8 +4,11 @@ import torch import torch.distributed as dist -from megablocks import benchmark_util -from megablocks.layers.all_to_all import all_to_all +# from megablocks import benchmark_util +# from megablocks.layers.all_to_all import all_to_all + +from .. import benchmark_util +from ..layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_gather.py index 89cce1b627137d7f987baca62fd6d82c5c04659a..8ae2ad8388b06db46a13f8fa46083619c44eefe2 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_gather kernel. 
diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_scatter.py index f5ce0d6f5a890a58b89e6a501216a9822341323f..6d8654bdc718cd893f7899e9cdb6cd20544d189c 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_scatter kernel. diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/cumsum.py index 0d7c5137c900d66a959a76fa681436362a4df906..e2b7572391e20045d335cf7337246e8a9b9f57ef 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/cumsum.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/cumsum.py @@ -11,7 +11,7 @@ import torch # instructions for building the c++ operations. try: # import megablocks_ops as ops # type: ignore - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/gather.py index 41b09a1233e8a996ff1062579cd0810d095ad1e6..4edf4541dac52abab94151efb414acaa7711f8f6 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for gather kernel. diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram.py index 77e06a29163aa335dc61e8c50932d5f83322ae1d..7b3f058ec373cbba7555704fb5e4212c3cc75d9d 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py index 9de8e65281c829ba63b5c60853dfa5bb0a333988..c57b7bf8228e01237236748147368b09ffdf8072 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. 
import ops _HISTOGRAM_TESTS = ( (16384, torch.int32, 2), diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py index bfa7b7c2f697f9454c5381e31ee78040be5b229e..9b2c7f062071758def741bf3af2392d3a4855f89 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -7,7 +7,7 @@ import stk import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops # Calling tensor.t() calls tensor.transpose(0, 1) which calls diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_gather.py index f272a7768dc6468e81bd2cd25294ca16a6826c08..0ffe1369d6adbce5c4a54155c636d1a4c022a41d 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_gather kernel. diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter.py index 9ff81dd9456a04258346266e9e85871bda56c65b..6685b0b83e7b1d92e7a772715991ead8ad94b153 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_scatter kernel. diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py index 81dde4e4b4c0f550c6027390e8de285fa4d842f7..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops _PADDED_SCATTER_BENCHMARK = ( # dMoE-Medium, 8-way EMP. diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py index 837f07e2858d5c65c14e3f262617e7d2985ce901..6536eeeae402659a087e5c51ef9840627af56501 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. 
import benchmark_util, ops _PERMUTE_TESTS = ( (16384, 768, 2), diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/replicate.py index e570e22093339d1b71dc830edf656897c67f40e7..26daf0eede330603a4b8ea7167faf1411d07ca93 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/replicate.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/replicate.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/scatter.py index a5aaafc402a651a47595c25dbf71e382903e5022..3f4bacf3422543e11abb795c279876147f8610a8 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/scatter.py @@ -6,7 +6,7 @@ from typing import Any, Optional import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for scatter kernel. diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort.py index a64a4553323d8de71a1d5f1821bac50249c68197..bda3bf64283e39533c2eae3627e76bb2d0262c9f 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py index f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec..a92ff957d4c552c6e61d9279a7989795472af7b7 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. import ops _SORT_TESTS = ( (16384, torch.int32, None), diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/topology.py index b4a693a82a36c052e51afb4a02077e7849a4abf1..76a50d3164db20534b099dcb4d8487a7aef25d15 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/topology.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/topology.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. 
try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..a6ddf9c7634f8b8889d37a6f6e774c1ae01f3291 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:69198313a732d74196039b0c350fe1d4e14448abdf9b4c2a6215eabd9b0e171c +size 11923704 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so deleted file mode 100755 index bc0ec44d4aeb316188bc43b4704fef6028c9af92..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:bd817eed5069e786933346cb2bb5ab6f586878ae80647191932336dec3295c96 -size 11923704 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py index fb18e304d3a97b1801f4f28ccd9b74403f400f6f..c9a9faa3cee07c8b7016a71704ea26e47c460d57 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_a585153_dirty -ops = torch.ops._megablocks_a585153_dirty +from . import _megablocks_6756875_dirty +ops = torch.ops._megablocks_6756875_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py index 7c692f95852950a7b97aaaa2aa3d15e4297b6192..76037d8039cbfc2f0577275c78e4bc0be762592a 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -10,7 +10,8 @@ import torch # We import the backend operations from the megablocks package as # grouped_gemm is vendored in megablocks in this repository. # from ... 
import _ops as backend -from megablocks._ops import ops as backend # type: ignore +# from megablocks._ops import ops as backend # type: ignore +from .._ops import ops as backend # type: ignore def _allocate_output(a, b, batch_sizes, trans_a, trans_b): assert not (trans_a and trans_b) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/arguments.py index 3962c771c90012535aab058443f89c541a2e9236..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/arguments.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/arguments.py @@ -9,7 +9,8 @@ import torch import torch.distributed as dist import torch.nn.functional as F -import megablocks.grouped_gemm_util as grouped_gemm +# import megablocks.grouped_gemm_util as grouped_gemm +from .. import grouped_gemm_util as grouped_gemm # Type annotation for in-place Tensor initialization function. InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/common.py index ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927..2d07109702963ba48a3b94ab860807954dfd79c1 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/common.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/common.py @@ -3,7 +3,7 @@ import torch -from megablocks.layers.arguments import Arguments +from .arguments import Arguments def dtype(args: Arguments): diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py index d765bd04387a29bddd2789fae04821635b555e82..de2ed047042e438c7190ebb139b6f7f30009734c 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -3,8 +3,8 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +from . import glu, mlp +from .arguments import Arguments MlpType = Union[mlp.SparseMLP, glu.SparseGLU] diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmoe.py index 205727ff4d63f9e8dc9648acaac99a97f3394d6f..561fd7982254513e9a9b4f0514edd326b0561256 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmoe.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmoe.py @@ -6,11 +6,14 @@ import stk.ops import torch from stk import Matrix -import megablocks.ops as ops -# from megablocks.ops import ops -from megablocks.layers import common, dmlp_registry, moe, mpu -from megablocks.layers.arguments import Arguments - +# import megablocks.ops as ops +# # from megablocks.ops import ops +# from megablocks.layers import common, dmlp_registry, moe, mpu +# from megablocks.layers.arguments import Arguments + +from .. import ops +from . 
import common, dmlp_registry, moe, mpu +from .arguments import Arguments def promote_scalar(x): return x.view(1) if not len(x.size()) else x diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/glu.py index cbe0c915c307e7f7cade3ea3ff679399635fcd81..69ef9ead5461a5b91ace107652fd8d1c3d300b1e 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/glu.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/glu.py @@ -4,11 +4,22 @@ import stk.ops import torch -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import Arguments -from megablocks.layers.mlp import ( +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import Arguments +# from megablocks.layers.mlp import ( +# SharedMLP, +# SparseMLP, +# create_dmoe_expert_weights, +# resolve_dtensor, +# ) + +from .. import grouped_gemm_util as gg +from . import common, mpu +from .activation_fn import act_fn +from .arguments import Arguments +from .mlp import ( SharedMLP, SparseMLP, create_dmoe_expert_weights, diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/memory_test.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/memory_test.py index 4acbd94f212ce906ace9de78dd3ffc6afa03f97e..74d1166931b712635131985b25a89f4ca23e576d 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/memory_test.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/memory_test.py @@ -6,7 +6,8 @@ import gc import torch import torch.distributed as dist -from megablocks.layers import arguments, dmoe +# from megablocks.layers import arguments, dmoe +from . import arguments, dmoe _TESTS = ((8, 2048, 4096, 4096, 32, 4),) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mlp.py index 6e6f4d82441e3a9c185db5bfdf686d53790dde26..0b6747b1eec2047eb9480592ca550b6c9e868f5e 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mlp.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mlp.py @@ -9,11 +9,15 @@ import stk.ops import torch from packaging import version -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, gelu, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, gelu, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + +from .. import grouped_gemm_util as gg +from . 
import common, gelu, mpu +from .activation_fn import act_fn +from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn class ScaleGradient(torch.autograd.Function): diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/moe.py index 9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e..d0a4aeaacc9c86fc70944e730c53f7a55644e05e 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/moe.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/moe.py @@ -6,10 +6,27 @@ import numpy as np import torch import torch.distributed as dist -import megablocks.ops as ops -from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -from megablocks.layers.all_to_all import all_to_all -from megablocks.layers.arguments import Arguments +# import megablocks.ops as ops +# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +# from megablocks.layers.all_to_all import all_to_all +# from megablocks.layers.arguments import Arguments + +from ..ops import ( + sort, + histogram, + inclusive_cumsum, + exclusive_cumsum, + binned_gather, + binned_scatter, + gather, + scatter, + repeat, + replicate, +) + +from . import common, mlp, mpu, router, sharedexpert_registry +from .arguments import Arguments +from .all_to_all import all_to_all _LOAD_BALANCING_LOSS = [] @@ -158,7 +175,8 @@ class ParallelMLP(torch.nn.Module): # prior? Could we place the `torch.max` operation to return # 32-bit expert indices? top_expert = top_expert.int() - output = ops.sort(top_expert, self.sort_end_bit) + # output = ops.sort(top_expert, self.sort_end_bit) + output = sort(top_expert, self.sort_end_bit) assert output is not None bin_ids, indices = output @@ -168,10 +186,12 @@ class ParallelMLP(torch.nn.Module): # TODO(tgale): Does the sorted data produce a more favorable # data distribution for histogram? Or is the op parallelism # worth more? - tokens_per_expert = ops.histogram(top_expert, self.num_experts) + # tokens_per_expert = ops.histogram(top_expert, self.num_experts) + tokens_per_expert = histogram(top_expert, self.num_experts) # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) + # bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = inclusive_cumsum(tokens_per_expert, 0) assert bins is not None bins = bins.view(1) if not len(bins.size()) else bins @@ -195,7 +215,8 @@ class ParallelMLP(torch.nn.Module): ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + output = binned_gather(x, indices, bins, expert_capacity, top_k) assert output is not None x = output @@ -204,7 +225,9 @@ class ParallelMLP(torch.nn.Module): x = self.mlp(x) # Un-route the data for the MoE output. - return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return binned_scatter(x, indices, expert_weights, bins, top_k) + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # x: [sl, bs, hs] @@ -264,7 +287,8 @@ class ParallelMLP(torch.nn.Module): # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. # Replicate the token counts so every device gets the counts. 
- repeated_tokens_per_expert = ops.repeat( + # repeated_tokens_per_expert = ops.repeat( + repeated_tokens_per_expert = repeat( tokens_per_expert, (mpu.hidden_sharding_degree(self.args),), ) @@ -285,7 +309,8 @@ class ParallelMLP(torch.nn.Module): # This view updates the shape of the tensor from [sl, bs, hs] to # [sl * bs, hs] prior to the permutation. x = x.view(-1, x.shape[-1]) - output = ops.gather(x, indices, bin_ids, bins, self.top_k) + # output = ops.gather(x, indices, bin_ids, bins, self.top_k) + output = gather(x, indices, bin_ids, bins, self.top_k) assert output is not None x = output @@ -317,7 +342,8 @@ class ParallelMLP(torch.nn.Module): # get all of the tokens assigned to them. # # TODO(tgale): Fuse this into the prior, local permutation. - x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) # Start the cross-device permutation asynchronously so we can # overlap communication with computation. @@ -336,7 +362,8 @@ class ParallelMLP(torch.nn.Module): # for expert computation we'll do one more local permutation. The # rest of this torch.no_grad() scope sets up the indices and bins # for this permutation. - replicate_bins = ops.inclusive_cumsum( + # replicate_bins = ops.inclusive_cumsum( + replicate_bins = inclusive_cumsum( parallel_tokens_per_expert.flatten(), 0, ) @@ -351,14 +378,16 @@ class ParallelMLP(torch.nn.Module): ), mpu.experts_per_rank(self.args), ) - parallel_top_expert = ops.replicate( + # parallel_top_expert = ops.replicate( + parallel_top_expert = replicate( parallel_top_expert.unsqueeze(dim=0), replicate_bins, tokens_received, ).flatten() # TODO(tgale): The sort_end_bit here can be reduced. - parallel_bin_ids, parallel_indices = ops.sort( + # parallel_bin_ids, parallel_indices = ops.sort( + parallel_bin_ids, parallel_indices = sort( parallel_top_expert, self.sort_end_bit, ) @@ -368,7 +397,8 @@ class ParallelMLP(torch.nn.Module): dim=0, dtype=torch.int, ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) # If expert_capacity is set to zero, set the number of tokens @@ -416,10 +446,12 @@ class ParallelMLP(torch.nn.Module): -1, self.args.hidden_size, ) - x = ops.sum(x.view(shape), dim=0) + # x = ops.sum(x.view(shape), dim=0) + x = x.view(shape).sum(dim=0) # Un-permute locally to setup for the next series of operations. 
- x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mpu.py index b23213902f4567d7fdb0158cbcf5406c2b2aa601..434e143ab42bf3f83406d69e9dd1f72777716e22 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mpu.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mpu.py @@ -6,7 +6,8 @@ from typing import Optional import torch import torch.distributed as dist -from megablocks.layers.arguments import Arguments +# from megablocks.layers.arguments import Arguments +from .arguments import Arguments class MoeParam(torch.Tensor): diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/router.py index 2c9dcd9e433322482a80d5b95afccee5c12368f8..37cb2782348d62583376f1a183c7ede83601216d 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/router.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/router.py @@ -4,8 +4,10 @@ from typing import Any import torch -from megablocks.layers import common -from megablocks.layers.arguments import Arguments +# from megablocks.layers import common +# from megablocks.layers.arguments import Arguments +from . import common +from .arguments import Arguments _ROUTER_LOGITS = [] diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py index 0f62db39a2dda5642f6baa9d9b949f7c31cf6d35..5840862f88f370ace5fd49bd0612fc98d186cc49 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -3,8 +3,10 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +# from megablocks.layers import glu, mlp +# from megablocks.layers.arguments import Arguments +from . 
import glu, mlp +from .arguments import Arguments _REGISTRY = { 'mlp': mlp.SharedMLP, diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/__init__.py index b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c..b944080df810d0b0cfc571f3009b0098a651f9b7 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/__init__.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/__init__.py @@ -1,20 +1,20 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from megablocks.ops.binned_gather import binned_gather -from megablocks.ops.binned_scatter import binned_scatter -from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum -from megablocks.ops.gather import gather -from megablocks.ops.histogram import histogram -from megablocks.ops.padded_gather import padded_gather -from megablocks.ops.padded_scatter import padded_scatter -from megablocks.ops.repeat import repeat -from megablocks.ops.replicate import replicate -from megablocks.ops.round_up import round_up -from megablocks.ops.scatter import scatter -from megablocks.ops.sort import sort -from megablocks.ops.sum import sum -from megablocks.ops.topology import topology +from .binned_gather import binned_gather +from .binned_scatter import binned_scatter +from .cumsum import exclusive_cumsum, inclusive_cumsum +from .gather import gather +from .histogram import histogram +from .padded_gather import padded_gather +from .padded_scatter import padded_scatter +from .repeat import repeat +from .replicate import replicate +from .round_up import round_up +from .scatter import scatter +from .sort import sort +from .sum import sum +from .topology import topology __all__ = [ 'binned_gather', diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 47b95301dbb35b0c3154d1ff2878d20792ce5cb2..43d267dbe2570e4a1f59fd561398668cc2bc0920 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -4,8 +4,11 @@ import torch import torch.distributed as dist -from megablocks import benchmark_util -from megablocks.layers.all_to_all import all_to_all +# from megablocks import benchmark_util +# from megablocks.layers.all_to_all import all_to_all + +from .. import benchmark_util +from ..layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_gather.py index 89cce1b627137d7f987baca62fd6d82c5c04659a..8ae2ad8388b06db46a13f8fa46083619c44eefe2 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_gather kernel. 
diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_scatter.py index f5ce0d6f5a890a58b89e6a501216a9822341323f..6d8654bdc718cd893f7899e9cdb6cd20544d189c 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_scatter kernel. diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/cumsum.py index 0d7c5137c900d66a959a76fa681436362a4df906..e2b7572391e20045d335cf7337246e8a9b9f57ef 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/cumsum.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/cumsum.py @@ -11,7 +11,7 @@ import torch # instructions for building the c++ operations. try: # import megablocks_ops as ops # type: ignore - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/gather.py index 41b09a1233e8a996ff1062579cd0810d095ad1e6..4edf4541dac52abab94151efb414acaa7711f8f6 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for gather kernel. diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram.py index 77e06a29163aa335dc61e8c50932d5f83322ae1d..7b3f058ec373cbba7555704fb5e4212c3cc75d9d 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py index 9de8e65281c829ba63b5c60853dfa5bb0a333988..c57b7bf8228e01237236748147368b09ffdf8072 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. 
import ops _HISTOGRAM_TESTS = ( (16384, torch.int32, 2), diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py index bfa7b7c2f697f9454c5381e31ee78040be5b229e..9b2c7f062071758def741bf3af2392d3a4855f89 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -7,7 +7,7 @@ import stk import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops # Calling tensor.t() calls tensor.transpose(0, 1) which calls diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_gather.py index f272a7768dc6468e81bd2cd25294ca16a6826c08..0ffe1369d6adbce5c4a54155c636d1a4c022a41d 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_gather kernel. diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter.py index 9ff81dd9456a04258346266e9e85871bda56c65b..6685b0b83e7b1d92e7a772715991ead8ad94b153 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_scatter kernel. diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py index 81dde4e4b4c0f550c6027390e8de285fa4d842f7..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops _PADDED_SCATTER_BENCHMARK = ( # dMoE-Medium, 8-way EMP. diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py index 837f07e2858d5c65c14e3f262617e7d2985ce901..6536eeeae402659a087e5c51ef9840627af56501 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. 
import benchmark_util, ops _PERMUTE_TESTS = ( (16384, 768, 2), diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/replicate.py index e570e22093339d1b71dc830edf656897c67f40e7..26daf0eede330603a4b8ea7167faf1411d07ca93 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/replicate.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/replicate.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/scatter.py index a5aaafc402a651a47595c25dbf71e382903e5022..3f4bacf3422543e11abb795c279876147f8610a8 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/scatter.py @@ -6,7 +6,7 @@ from typing import Any, Optional import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for scatter kernel. diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort.py index a64a4553323d8de71a1d5f1821bac50249c68197..bda3bf64283e39533c2eae3627e76bb2d0262c9f 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py index f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec..a92ff957d4c552c6e61d9279a7989795472af7b7 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. import ops _SORT_TESTS = ( (16384, torch.int32, None), diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/topology.py index b4a693a82a36c052e51afb4a02077e7849a4abf1..76a50d3164db20534b099dcb4d8487a7aef25d15 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/topology.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/topology.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. 
try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..2edf877dce4fdf86bffbaa772e48aeb7f456199b --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f535bc4e3bf6c970117d2dbf7f7b4dcdabf1fea035590851c2a558ca0e95497f +size 10517848 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so deleted file mode 100755 index b774ea5c271fea98a48c9c21d7fb1a824d73e065..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c6edc24e0f999c3ecb6120715db285792c9f19dbe387348fe0e0f25c72e97138 -size 10517848 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py index fb18e304d3a97b1801f4f28ccd9b74403f400f6f..c9a9faa3cee07c8b7016a71704ea26e47c460d57 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_a585153_dirty -ops = torch.ops._megablocks_a585153_dirty +from . import _megablocks_6756875_dirty +ops = torch.ops._megablocks_6756875_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py index 7c692f95852950a7b97aaaa2aa3d15e4297b6192..76037d8039cbfc2f0577275c78e4bc0be762592a 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -10,7 +10,8 @@ import torch # We import the backend operations from the megablocks package as # grouped_gemm is vendored in megablocks in this repository. # from ... 
import _ops as backend -from megablocks._ops import ops as backend # type: ignore +# from megablocks._ops import ops as backend # type: ignore +from .._ops import ops as backend # type: ignore def _allocate_output(a, b, batch_sizes, trans_a, trans_b): assert not (trans_a and trans_b) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py index 3962c771c90012535aab058443f89c541a2e9236..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py @@ -9,7 +9,8 @@ import torch import torch.distributed as dist import torch.nn.functional as F -import megablocks.grouped_gemm_util as grouped_gemm +# import megablocks.grouped_gemm_util as grouped_gemm +from .. import grouped_gemm_util as grouped_gemm # Type annotation for in-place Tensor initialization function. InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/common.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/common.py index ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927..2d07109702963ba48a3b94ab860807954dfd79c1 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/common.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/common.py @@ -3,7 +3,7 @@ import torch -from megablocks.layers.arguments import Arguments +from .arguments import Arguments def dtype(args: Arguments): diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py index d765bd04387a29bddd2789fae04821635b555e82..de2ed047042e438c7190ebb139b6f7f30009734c 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -3,8 +3,8 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +from . import glu, mlp +from .arguments import Arguments MlpType = Union[mlp.SparseMLP, glu.SparseGLU] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py index 205727ff4d63f9e8dc9648acaac99a97f3394d6f..561fd7982254513e9a9b4f0514edd326b0561256 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py @@ -6,11 +6,14 @@ import stk.ops import torch from stk import Matrix -import megablocks.ops as ops -# from megablocks.ops import ops -from megablocks.layers import common, dmlp_registry, moe, mpu -from megablocks.layers.arguments import Arguments - +# import megablocks.ops as ops +# # from megablocks.ops import ops +# from megablocks.layers import common, dmlp_registry, moe, mpu +# from megablocks.layers.arguments import Arguments + +from .. import ops +from . 
import common, dmlp_registry, moe, mpu +from .arguments import Arguments def promote_scalar(x): return x.view(1) if not len(x.size()) else x diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py index cbe0c915c307e7f7cade3ea3ff679399635fcd81..69ef9ead5461a5b91ace107652fd8d1c3d300b1e 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py @@ -4,11 +4,22 @@ import stk.ops import torch -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import Arguments -from megablocks.layers.mlp import ( +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import Arguments +# from megablocks.layers.mlp import ( +# SharedMLP, +# SparseMLP, +# create_dmoe_expert_weights, +# resolve_dtensor, +# ) + +from .. import grouped_gemm_util as gg +from . import common, mpu +from .activation_fn import act_fn +from .arguments import Arguments +from .mlp import ( SharedMLP, SparseMLP, create_dmoe_expert_weights, diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py index 4acbd94f212ce906ace9de78dd3ffc6afa03f97e..74d1166931b712635131985b25a89f4ca23e576d 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py @@ -6,7 +6,8 @@ import gc import torch import torch.distributed as dist -from megablocks.layers import arguments, dmoe +# from megablocks.layers import arguments, dmoe +from . import arguments, dmoe _TESTS = ((8, 2048, 4096, 4096, 32, 4),) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py index 6e6f4d82441e3a9c185db5bfdf686d53790dde26..0b6747b1eec2047eb9480592ca550b6c9e868f5e 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py @@ -9,11 +9,15 @@ import stk.ops import torch from packaging import version -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, gelu, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, gelu, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + +from .. import grouped_gemm_util as gg +from . 
import common, gelu, mpu +from .activation_fn import act_fn +from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn class ScaleGradient(torch.autograd.Function): diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py index 9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e..d0a4aeaacc9c86fc70944e730c53f7a55644e05e 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py @@ -6,10 +6,27 @@ import numpy as np import torch import torch.distributed as dist -import megablocks.ops as ops -from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -from megablocks.layers.all_to_all import all_to_all -from megablocks.layers.arguments import Arguments +# import megablocks.ops as ops +# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +# from megablocks.layers.all_to_all import all_to_all +# from megablocks.layers.arguments import Arguments + +from ..ops import ( + sort, + histogram, + inclusive_cumsum, + exclusive_cumsum, + binned_gather, + binned_scatter, + gather, + scatter, + repeat, + replicate, +) + +from . import common, mlp, mpu, router, sharedexpert_registry +from .arguments import Arguments +from .all_to_all import all_to_all _LOAD_BALANCING_LOSS = [] @@ -158,7 +175,8 @@ class ParallelMLP(torch.nn.Module): # prior? Could we place the `torch.max` operation to return # 32-bit expert indices? top_expert = top_expert.int() - output = ops.sort(top_expert, self.sort_end_bit) + # output = ops.sort(top_expert, self.sort_end_bit) + output = sort(top_expert, self.sort_end_bit) assert output is not None bin_ids, indices = output @@ -168,10 +186,12 @@ class ParallelMLP(torch.nn.Module): # TODO(tgale): Does the sorted data produce a more favorable # data distribution for histogram? Or is the op parallelism # worth more? - tokens_per_expert = ops.histogram(top_expert, self.num_experts) + # tokens_per_expert = ops.histogram(top_expert, self.num_experts) + tokens_per_expert = histogram(top_expert, self.num_experts) # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) + # bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = inclusive_cumsum(tokens_per_expert, 0) assert bins is not None bins = bins.view(1) if not len(bins.size()) else bins @@ -195,7 +215,8 @@ class ParallelMLP(torch.nn.Module): ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + output = binned_gather(x, indices, bins, expert_capacity, top_k) assert output is not None x = output @@ -204,7 +225,9 @@ class ParallelMLP(torch.nn.Module): x = self.mlp(x) # Un-route the data for the MoE output. - return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return binned_scatter(x, indices, expert_weights, bins, top_k) + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # x: [sl, bs, hs] @@ -264,7 +287,8 @@ class ParallelMLP(torch.nn.Module): # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. # Replicate the token counts so every device gets the counts. 
- repeated_tokens_per_expert = ops.repeat( + # repeated_tokens_per_expert = ops.repeat( + repeated_tokens_per_expert = repeat( tokens_per_expert, (mpu.hidden_sharding_degree(self.args),), ) @@ -285,7 +309,8 @@ class ParallelMLP(torch.nn.Module): # This view updates the shape of the tensor from [sl, bs, hs] to # [sl * bs, hs] prior to the permutation. x = x.view(-1, x.shape[-1]) - output = ops.gather(x, indices, bin_ids, bins, self.top_k) + # output = ops.gather(x, indices, bin_ids, bins, self.top_k) + output = gather(x, indices, bin_ids, bins, self.top_k) assert output is not None x = output @@ -317,7 +342,8 @@ class ParallelMLP(torch.nn.Module): # get all of the tokens assigned to them. # # TODO(tgale): Fuse this into the prior, local permutation. - x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) # Start the cross-device permutation asynchronously so we can # overlap communication with computation. @@ -336,7 +362,8 @@ class ParallelMLP(torch.nn.Module): # for expert computation we'll do one more local permutation. The # rest of this torch.no_grad() scope sets up the indices and bins # for this permutation. - replicate_bins = ops.inclusive_cumsum( + # replicate_bins = ops.inclusive_cumsum( + replicate_bins = inclusive_cumsum( parallel_tokens_per_expert.flatten(), 0, ) @@ -351,14 +378,16 @@ class ParallelMLP(torch.nn.Module): ), mpu.experts_per_rank(self.args), ) - parallel_top_expert = ops.replicate( + # parallel_top_expert = ops.replicate( + parallel_top_expert = replicate( parallel_top_expert.unsqueeze(dim=0), replicate_bins, tokens_received, ).flatten() # TODO(tgale): The sort_end_bit here can be reduced. - parallel_bin_ids, parallel_indices = ops.sort( + # parallel_bin_ids, parallel_indices = ops.sort( + parallel_bin_ids, parallel_indices = sort( parallel_top_expert, self.sort_end_bit, ) @@ -368,7 +397,8 @@ class ParallelMLP(torch.nn.Module): dim=0, dtype=torch.int, ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) # If expert_capacity is set to zero, set the number of tokens @@ -416,10 +446,12 @@ class ParallelMLP(torch.nn.Module): -1, self.args.hidden_size, ) - x = ops.sum(x.view(shape), dim=0) + # x = ops.sum(x.view(shape), dim=0) + x = x.view(shape).sum(dim=0) # Un-permute locally to setup for the next series of operations. 
- x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py index b23213902f4567d7fdb0158cbcf5406c2b2aa601..434e143ab42bf3f83406d69e9dd1f72777716e22 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py @@ -6,7 +6,8 @@ from typing import Optional import torch import torch.distributed as dist -from megablocks.layers.arguments import Arguments +# from megablocks.layers.arguments import Arguments +from .arguments import Arguments class MoeParam(torch.Tensor): diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/router.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/router.py index 2c9dcd9e433322482a80d5b95afccee5c12368f8..37cb2782348d62583376f1a183c7ede83601216d 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/router.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/router.py @@ -4,8 +4,10 @@ from typing import Any import torch -from megablocks.layers import common -from megablocks.layers.arguments import Arguments +# from megablocks.layers import common +# from megablocks.layers.arguments import Arguments +from . import common +from .arguments import Arguments _ROUTER_LOGITS = [] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py index 0f62db39a2dda5642f6baa9d9b949f7c31cf6d35..5840862f88f370ace5fd49bd0612fc98d186cc49 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -3,8 +3,10 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +# from megablocks.layers import glu, mlp +# from megablocks.layers.arguments import Arguments +from . 
import glu, mlp +from .arguments import Arguments _REGISTRY = { 'mlp': mlp.SharedMLP, diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py index b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c..b944080df810d0b0cfc571f3009b0098a651f9b7 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py @@ -1,20 +1,20 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from megablocks.ops.binned_gather import binned_gather -from megablocks.ops.binned_scatter import binned_scatter -from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum -from megablocks.ops.gather import gather -from megablocks.ops.histogram import histogram -from megablocks.ops.padded_gather import padded_gather -from megablocks.ops.padded_scatter import padded_scatter -from megablocks.ops.repeat import repeat -from megablocks.ops.replicate import replicate -from megablocks.ops.round_up import round_up -from megablocks.ops.scatter import scatter -from megablocks.ops.sort import sort -from megablocks.ops.sum import sum -from megablocks.ops.topology import topology +from .binned_gather import binned_gather +from .binned_scatter import binned_scatter +from .cumsum import exclusive_cumsum, inclusive_cumsum +from .gather import gather +from .histogram import histogram +from .padded_gather import padded_gather +from .padded_scatter import padded_scatter +from .repeat import repeat +from .replicate import replicate +from .round_up import round_up +from .scatter import scatter +from .sort import sort +from .sum import sum +from .topology import topology __all__ = [ 'binned_gather', diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 47b95301dbb35b0c3154d1ff2878d20792ce5cb2..43d267dbe2570e4a1f59fd561398668cc2bc0920 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -4,8 +4,11 @@ import torch import torch.distributed as dist -from megablocks import benchmark_util -from megablocks.layers.all_to_all import all_to_all +# from megablocks import benchmark_util +# from megablocks.layers.all_to_all import all_to_all + +from .. import benchmark_util +from ..layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py index 89cce1b627137d7f987baca62fd6d82c5c04659a..8ae2ad8388b06db46a13f8fa46083619c44eefe2 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_gather kernel. 
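The moe.py hunk above swaps the `ops.sort`, `ops.histogram` and `ops.inclusive_cumsum` calls for the same functions imported directly from `..ops`; the routing bookkeeping itself is unchanged. As a rough plain-PyTorch sketch of that bookkeeping (stock torch.sort/bincount/cumsum instead of the fused CUDA kernels, and made-up sizes), the token-to-expert binning looks like this:

import torch

def route_tokens(top_expert: torch.Tensor, num_experts: int):
    # Sort token slots by expert id; `indices` is the permutation that groups tokens.
    bin_ids, indices = torch.sort(top_expert)
    # Number of tokens assigned to each expert (what the histogram call computes).
    tokens_per_expert = torch.bincount(top_expert, minlength=num_experts)
    # Inclusive cumulative sum gives each expert's end offset into the sorted order.
    bins = torch.cumsum(tokens_per_expert, dim=0)
    return bin_ids, indices, tokens_per_expert, bins

top_expert = torch.tensor([2, 0, 1, 2, 0, 2])  # 6 tokens routed across 3 experts
bin_ids, indices, tokens_per_expert, bins = route_tokens(top_expert, num_experts=3)
print(tokens_per_expert.tolist())  # [2, 1, 3]
print(bins.tolist())               # [2, 3, 6]: expert e owns sorted slots [bins[e-1], bins[e])

The sketch only makes the data flow explicit; the diff keeps the fused megablocks kernels for the actual computation.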
diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py index f5ce0d6f5a890a58b89e6a501216a9822341323f..6d8654bdc718cd893f7899e9cdb6cd20544d189c 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_scatter kernel. diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py index 0d7c5137c900d66a959a76fa681436362a4df906..e2b7572391e20045d335cf7337246e8a9b9f57ef 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py @@ -11,7 +11,7 @@ import torch # instructions for building the c++ operations. try: # import megablocks_ops as ops # type: ignore - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py index 41b09a1233e8a996ff1062579cd0810d095ad1e6..4edf4541dac52abab94151efb414acaa7711f8f6 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for gather kernel. diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py index 77e06a29163aa335dc61e8c50932d5f83322ae1d..7b3f058ec373cbba7555704fb5e4212c3cc75d9d 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py index 9de8e65281c829ba63b5c60853dfa5bb0a333988..c57b7bf8228e01237236748147368b09ffdf8072 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. 
import ops _HISTOGRAM_TESTS = ( (16384, torch.int32, 2), diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py index bfa7b7c2f697f9454c5381e31ee78040be5b229e..9b2c7f062071758def741bf3af2392d3a4855f89 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -7,7 +7,7 @@ import stk import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops # Calling tensor.t() calls tensor.transpose(0, 1) which calls diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py index f272a7768dc6468e81bd2cd25294ca16a6826c08..0ffe1369d6adbce5c4a54155c636d1a4c022a41d 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_gather kernel. diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py index 9ff81dd9456a04258346266e9e85871bda56c65b..6685b0b83e7b1d92e7a772715991ead8ad94b153 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_scatter kernel. diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py index 81dde4e4b4c0f550c6027390e8de285fa4d842f7..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops _PADDED_SCATTER_BENCHMARK = ( # dMoE-Medium, 8-way EMP. diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py index 837f07e2858d5c65c14e3f262617e7d2985ce901..6536eeeae402659a087e5c51ef9840627af56501 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. 
import benchmark_util, ops _PERMUTE_TESTS = ( (16384, 768, 2), diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py index e570e22093339d1b71dc830edf656897c67f40e7..26daf0eede330603a4b8ea7167faf1411d07ca93 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py index a5aaafc402a651a47595c25dbf71e382903e5022..3f4bacf3422543e11abb795c279876147f8610a8 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py @@ -6,7 +6,7 @@ from typing import Any, Optional import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for scatter kernel. diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py index a64a4553323d8de71a1d5f1821bac50249c68197..bda3bf64283e39533c2eae3627e76bb2d0262c9f 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py index f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec..a92ff957d4c552c6e61d9279a7989795472af7b7 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. import ops _SORT_TESTS = ( (16384, torch.int32, None), diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py index b4a693a82a36c052e51afb4a02077e7849a4abf1..76a50d3164db20534b099dcb4d8487a7aef25d15 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. 
try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..229aa059ff8f5d788ce34d0d0707467f9051e603 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4708535fe6c5c8d59883ad81b5ce6b0083c3ed0ae71f14749ac2415899797f36 +size 11931112 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so deleted file mode 100755 index 6485829ebc3f4f120eac78ed896d64000416c3ac..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c8ad568311ee3b60bd6c53500a01213c767afd6d455b10f0b4e23a1a162839d8 -size 11931112 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py index fb18e304d3a97b1801f4f28ccd9b74403f400f6f..c9a9faa3cee07c8b7016a71704ea26e47c460d57 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_a585153_dirty -ops = torch.ops._megablocks_a585153_dirty +from . import _megablocks_6756875_dirty +ops = torch.ops._megablocks_6756875_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py index 7c692f95852950a7b97aaaa2aa3d15e4297b6192..76037d8039cbfc2f0577275c78e4bc0be762592a 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -10,7 +10,8 @@ import torch # We import the backend operations from the megablocks package as # grouped_gemm is vendored in megablocks in this repository. # from ... 
import _ops as backend -from megablocks._ops import ops as backend # type: ignore +# from megablocks._ops import ops as backend # type: ignore +from .._ops import ops as backend # type: ignore def _allocate_output(a, b, batch_sizes, trans_a, trans_b): assert not (trans_a and trans_b) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py index 3962c771c90012535aab058443f89c541a2e9236..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py @@ -9,7 +9,8 @@ import torch import torch.distributed as dist import torch.nn.functional as F -import megablocks.grouped_gemm_util as grouped_gemm +# import megablocks.grouped_gemm_util as grouped_gemm +from .. import grouped_gemm_util as grouped_gemm # Type annotation for in-place Tensor initialization function. InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/common.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/common.py index ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927..2d07109702963ba48a3b94ab860807954dfd79c1 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/common.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/common.py @@ -3,7 +3,7 @@ import torch -from megablocks.layers.arguments import Arguments +from .arguments import Arguments def dtype(args: Arguments): diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py index d765bd04387a29bddd2789fae04821635b555e82..de2ed047042e438c7190ebb139b6f7f30009734c 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -3,8 +3,8 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +from . import glu, mlp +from .arguments import Arguments MlpType = Union[mlp.SparseMLP, glu.SparseGLU] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py index 205727ff4d63f9e8dc9648acaac99a97f3394d6f..561fd7982254513e9a9b4f0514edd326b0561256 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py @@ -6,11 +6,14 @@ import stk.ops import torch from stk import Matrix -import megablocks.ops as ops -# from megablocks.ops import ops -from megablocks.layers import common, dmlp_registry, moe, mpu -from megablocks.layers.arguments import Arguments - +# import megablocks.ops as ops +# # from megablocks.ops import ops +# from megablocks.layers import common, dmlp_registry, moe, mpu +# from megablocks.layers.arguments import Arguments + +from .. import ops +from . 
import common, dmlp_registry, moe, mpu +from .arguments import Arguments def promote_scalar(x): return x.view(1) if not len(x.size()) else x diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py index cbe0c915c307e7f7cade3ea3ff679399635fcd81..69ef9ead5461a5b91ace107652fd8d1c3d300b1e 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py @@ -4,11 +4,22 @@ import stk.ops import torch -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import Arguments -from megablocks.layers.mlp import ( +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import Arguments +# from megablocks.layers.mlp import ( +# SharedMLP, +# SparseMLP, +# create_dmoe_expert_weights, +# resolve_dtensor, +# ) + +from .. import grouped_gemm_util as gg +from . import common, mpu +from .activation_fn import act_fn +from .arguments import Arguments +from .mlp import ( SharedMLP, SparseMLP, create_dmoe_expert_weights, diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py index 4acbd94f212ce906ace9de78dd3ffc6afa03f97e..74d1166931b712635131985b25a89f4ca23e576d 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py @@ -6,7 +6,8 @@ import gc import torch import torch.distributed as dist -from megablocks.layers import arguments, dmoe +# from megablocks.layers import arguments, dmoe +from . import arguments, dmoe _TESTS = ((8, 2048, 4096, 4096, 32, 4),) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py index 6e6f4d82441e3a9c185db5bfdf686d53790dde26..0b6747b1eec2047eb9480592ca550b6c9e868f5e 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py @@ -9,11 +9,15 @@ import stk.ops import torch from packaging import version -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, gelu, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, gelu, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + +from .. import grouped_gemm_util as gg +from . 
import common, gelu, mpu +from .activation_fn import act_fn +from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn class ScaleGradient(torch.autograd.Function): diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py index 9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e..d0a4aeaacc9c86fc70944e730c53f7a55644e05e 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py @@ -6,10 +6,27 @@ import numpy as np import torch import torch.distributed as dist -import megablocks.ops as ops -from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -from megablocks.layers.all_to_all import all_to_all -from megablocks.layers.arguments import Arguments +# import megablocks.ops as ops +# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +# from megablocks.layers.all_to_all import all_to_all +# from megablocks.layers.arguments import Arguments + +from ..ops import ( + sort, + histogram, + inclusive_cumsum, + exclusive_cumsum, + binned_gather, + binned_scatter, + gather, + scatter, + repeat, + replicate, +) + +from . import common, mlp, mpu, router, sharedexpert_registry +from .arguments import Arguments +from .all_to_all import all_to_all _LOAD_BALANCING_LOSS = [] @@ -158,7 +175,8 @@ class ParallelMLP(torch.nn.Module): # prior? Could we place the `torch.max` operation to return # 32-bit expert indices? top_expert = top_expert.int() - output = ops.sort(top_expert, self.sort_end_bit) + # output = ops.sort(top_expert, self.sort_end_bit) + output = sort(top_expert, self.sort_end_bit) assert output is not None bin_ids, indices = output @@ -168,10 +186,12 @@ class ParallelMLP(torch.nn.Module): # TODO(tgale): Does the sorted data produce a more favorable # data distribution for histogram? Or is the op parallelism # worth more? - tokens_per_expert = ops.histogram(top_expert, self.num_experts) + # tokens_per_expert = ops.histogram(top_expert, self.num_experts) + tokens_per_expert = histogram(top_expert, self.num_experts) # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) + # bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = inclusive_cumsum(tokens_per_expert, 0) assert bins is not None bins = bins.view(1) if not len(bins.size()) else bins @@ -195,7 +215,8 @@ class ParallelMLP(torch.nn.Module): ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + output = binned_gather(x, indices, bins, expert_capacity, top_k) assert output is not None x = output @@ -204,7 +225,9 @@ class ParallelMLP(torch.nn.Module): x = self.mlp(x) # Un-route the data for the MoE output. - return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return binned_scatter(x, indices, expert_weights, bins, top_k) + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # x: [sl, bs, hs] @@ -264,7 +287,8 @@ class ParallelMLP(torch.nn.Module): # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. # Replicate the token counts so every device gets the counts. 
- repeated_tokens_per_expert = ops.repeat( + # repeated_tokens_per_expert = ops.repeat( + repeated_tokens_per_expert = repeat( tokens_per_expert, (mpu.hidden_sharding_degree(self.args),), ) @@ -285,7 +309,8 @@ class ParallelMLP(torch.nn.Module): # This view updates the shape of the tensor from [sl, bs, hs] to # [sl * bs, hs] prior to the permutation. x = x.view(-1, x.shape[-1]) - output = ops.gather(x, indices, bin_ids, bins, self.top_k) + # output = ops.gather(x, indices, bin_ids, bins, self.top_k) + output = gather(x, indices, bin_ids, bins, self.top_k) assert output is not None x = output @@ -317,7 +342,8 @@ class ParallelMLP(torch.nn.Module): # get all of the tokens assigned to them. # # TODO(tgale): Fuse this into the prior, local permutation. - x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) # Start the cross-device permutation asynchronously so we can # overlap communication with computation. @@ -336,7 +362,8 @@ class ParallelMLP(torch.nn.Module): # for expert computation we'll do one more local permutation. The # rest of this torch.no_grad() scope sets up the indices and bins # for this permutation. - replicate_bins = ops.inclusive_cumsum( + # replicate_bins = ops.inclusive_cumsum( + replicate_bins = inclusive_cumsum( parallel_tokens_per_expert.flatten(), 0, ) @@ -351,14 +378,16 @@ class ParallelMLP(torch.nn.Module): ), mpu.experts_per_rank(self.args), ) - parallel_top_expert = ops.replicate( + # parallel_top_expert = ops.replicate( + parallel_top_expert = replicate( parallel_top_expert.unsqueeze(dim=0), replicate_bins, tokens_received, ).flatten() # TODO(tgale): The sort_end_bit here can be reduced. - parallel_bin_ids, parallel_indices = ops.sort( + # parallel_bin_ids, parallel_indices = ops.sort( + parallel_bin_ids, parallel_indices = sort( parallel_top_expert, self.sort_end_bit, ) @@ -368,7 +397,8 @@ class ParallelMLP(torch.nn.Module): dim=0, dtype=torch.int, ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) # If expert_capacity is set to zero, set the number of tokens @@ -416,10 +446,12 @@ class ParallelMLP(torch.nn.Module): -1, self.args.hidden_size, ) - x = ops.sum(x.view(shape), dim=0) + # x = ops.sum(x.view(shape), dim=0) + x = x.view(shape).sum(dim=0) # Un-permute locally to setup for the next series of operations. 
- x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py index b23213902f4567d7fdb0158cbcf5406c2b2aa601..434e143ab42bf3f83406d69e9dd1f72777716e22 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py @@ -6,7 +6,8 @@ from typing import Optional import torch import torch.distributed as dist -from megablocks.layers.arguments import Arguments +# from megablocks.layers.arguments import Arguments +from .arguments import Arguments class MoeParam(torch.Tensor): diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/router.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/router.py index 2c9dcd9e433322482a80d5b95afccee5c12368f8..37cb2782348d62583376f1a183c7ede83601216d 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/router.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/router.py @@ -4,8 +4,10 @@ from typing import Any import torch -from megablocks.layers import common -from megablocks.layers.arguments import Arguments +# from megablocks.layers import common +# from megablocks.layers.arguments import Arguments +from . import common +from .arguments import Arguments _ROUTER_LOGITS = [] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py index 0f62db39a2dda5642f6baa9d9b949f7c31cf6d35..5840862f88f370ace5fd49bd0612fc98d186cc49 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -3,8 +3,10 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +# from megablocks.layers import glu, mlp +# from megablocks.layers.arguments import Arguments +from . 
import glu, mlp +from .arguments import Arguments _REGISTRY = { 'mlp': mlp.SharedMLP, diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py index b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c..b944080df810d0b0cfc571f3009b0098a651f9b7 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py @@ -1,20 +1,20 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from megablocks.ops.binned_gather import binned_gather -from megablocks.ops.binned_scatter import binned_scatter -from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum -from megablocks.ops.gather import gather -from megablocks.ops.histogram import histogram -from megablocks.ops.padded_gather import padded_gather -from megablocks.ops.padded_scatter import padded_scatter -from megablocks.ops.repeat import repeat -from megablocks.ops.replicate import replicate -from megablocks.ops.round_up import round_up -from megablocks.ops.scatter import scatter -from megablocks.ops.sort import sort -from megablocks.ops.sum import sum -from megablocks.ops.topology import topology +from .binned_gather import binned_gather +from .binned_scatter import binned_scatter +from .cumsum import exclusive_cumsum, inclusive_cumsum +from .gather import gather +from .histogram import histogram +from .padded_gather import padded_gather +from .padded_scatter import padded_scatter +from .repeat import repeat +from .replicate import replicate +from .round_up import round_up +from .scatter import scatter +from .sort import sort +from .sum import sum +from .topology import topology __all__ = [ 'binned_gather', diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 47b95301dbb35b0c3154d1ff2878d20792ce5cb2..43d267dbe2570e4a1f59fd561398668cc2bc0920 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -4,8 +4,11 @@ import torch import torch.distributed as dist -from megablocks import benchmark_util -from megablocks.layers.all_to_all import all_to_all +# from megablocks import benchmark_util +# from megablocks.layers.all_to_all import all_to_all + +from .. import benchmark_util +from ..layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py index 89cce1b627137d7f987baca62fd6d82c5c04659a..8ae2ad8388b06db46a13f8fa46083619c44eefe2 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_gather kernel. 
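One behavioural detail in the moe.py hunks: `ops.sum(x.view(shape), dim=0)` becomes the native `x.view(shape).sum(dim=0)`, folding the hidden-sharding replicas back into a single set of token activations with plain torch. A tiny sketch of that reduction, using illustrative sizes rather than the real model dimensions:

import torch

hidden_sharding_degree, tokens, hidden_size = 2, 4, 8   # illustrative sizes only
x = torch.randn(hidden_sharding_degree * tokens, hidden_size)

shape = (hidden_sharding_degree, -1, hidden_size)
reduced = x.view(shape).sum(dim=0)               # what the updated code does
reference = x.view(shape)[0] + x.view(shape)[1]  # element-wise sum of the two replicas
assert torch.allclose(reduced, reference)
print(reduced.shape)                             # torch.Size([4, 8])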
diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py index f5ce0d6f5a890a58b89e6a501216a9822341323f..6d8654bdc718cd893f7899e9cdb6cd20544d189c 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_scatter kernel. diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py index 0d7c5137c900d66a959a76fa681436362a4df906..e2b7572391e20045d335cf7337246e8a9b9f57ef 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py @@ -11,7 +11,7 @@ import torch # instructions for building the c++ operations. try: # import megablocks_ops as ops # type: ignore - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py index 41b09a1233e8a996ff1062579cd0810d095ad1e6..4edf4541dac52abab94151efb414acaa7711f8f6 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for gather kernel. diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py index 77e06a29163aa335dc61e8c50932d5f83322ae1d..7b3f058ec373cbba7555704fb5e4212c3cc75d9d 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py index 9de8e65281c829ba63b5c60853dfa5bb0a333988..c57b7bf8228e01237236748147368b09ffdf8072 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. 
import ops _HISTOGRAM_TESTS = ( (16384, torch.int32, 2), diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py index bfa7b7c2f697f9454c5381e31ee78040be5b229e..9b2c7f062071758def741bf3af2392d3a4855f89 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -7,7 +7,7 @@ import stk import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops # Calling tensor.t() calls tensor.transpose(0, 1) which calls diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py index f272a7768dc6468e81bd2cd25294ca16a6826c08..0ffe1369d6adbce5c4a54155c636d1a4c022a41d 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_gather kernel. diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py index 9ff81dd9456a04258346266e9e85871bda56c65b..6685b0b83e7b1d92e7a772715991ead8ad94b153 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_scatter kernel. diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py index 81dde4e4b4c0f550c6027390e8de285fa4d842f7..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops _PADDED_SCATTER_BENCHMARK = ( # dMoE-Medium, 8-way EMP. diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py index 837f07e2858d5c65c14e3f262617e7d2985ce901..6536eeeae402659a087e5c51ef9840627af56501 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. 
import benchmark_util, ops _PERMUTE_TESTS = ( (16384, 768, 2), diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py index e570e22093339d1b71dc830edf656897c67f40e7..26daf0eede330603a4b8ea7167faf1411d07ca93 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py index a5aaafc402a651a47595c25dbf71e382903e5022..3f4bacf3422543e11abb795c279876147f8610a8 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py @@ -6,7 +6,7 @@ from typing import Any, Optional import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for scatter kernel. diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py index a64a4553323d8de71a1d5f1821bac50249c68197..bda3bf64283e39533c2eae3627e76bb2d0262c9f 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py index f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec..a92ff957d4c552c6e61d9279a7989795472af7b7 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. import ops _SORT_TESTS = ( (16384, torch.int32, None), diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py index b4a693a82a36c052e51afb4a02077e7849a4abf1..76a50d3164db20534b099dcb4d8487a7aef25d15 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. 
try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..50000f9b7cb45cf014efe881e2c53c60f3497663 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae05d16e382a38c3a908f350e2347206319b5a0a9f9cd2447464b6409dc10fce +size 17892656 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so deleted file mode 100755 index d2b18b163b48d73368dc51792783bbddb2c00622..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ca88225928254532529d6a45b819b44f4932085ac90d966ea31a824777c4d581 -size 17892656 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py index fb18e304d3a97b1801f4f28ccd9b74403f400f6f..c9a9faa3cee07c8b7016a71704ea26e47c460d57 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_a585153_dirty -ops = torch.ops._megablocks_a585153_dirty +from . import _megablocks_6756875_dirty +ops = torch.ops._megablocks_6756875_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/backend.py index 7c692f95852950a7b97aaaa2aa3d15e4297b6192..76037d8039cbfc2f0577275c78e4bc0be762592a 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/backend.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -10,7 +10,8 @@ import torch # We import the backend operations from the megablocks package as # grouped_gemm is vendored in megablocks in this repository. # from ... 
import _ops as backend -from megablocks._ops import ops as backend # type: ignore +# from megablocks._ops import ops as backend # type: ignore +from .._ops import ops as backend # type: ignore def _allocate_output(a, b, batch_sizes, trans_a, trans_b): assert not (trans_a and trans_b) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/arguments.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/arguments.py index 3962c771c90012535aab058443f89c541a2e9236..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/arguments.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/arguments.py @@ -9,7 +9,8 @@ import torch import torch.distributed as dist import torch.nn.functional as F -import megablocks.grouped_gemm_util as grouped_gemm +# import megablocks.grouped_gemm_util as grouped_gemm +from .. import grouped_gemm_util as grouped_gemm # Type annotation for in-place Tensor initialization function. InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/common.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/common.py index ee30e79374a8c6e9e49f8e5b1eccc782ae6cb927..2d07109702963ba48a3b94ab860807954dfd79c1 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/common.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/common.py @@ -3,7 +3,7 @@ import torch -from megablocks.layers.arguments import Arguments +from .arguments import Arguments def dtype(args: Arguments): diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmlp_registry.py index d765bd04387a29bddd2789fae04821635b555e82..de2ed047042e438c7190ebb139b6f7f30009734c 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmlp_registry.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmlp_registry.py @@ -3,8 +3,8 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +from . import glu, mlp +from .arguments import Arguments MlpType = Union[mlp.SparseMLP, glu.SparseGLU] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmoe.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmoe.py index 205727ff4d63f9e8dc9648acaac99a97f3394d6f..561fd7982254513e9a9b4f0514edd326b0561256 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmoe.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmoe.py @@ -6,11 +6,14 @@ import stk.ops import torch from stk import Matrix -import megablocks.ops as ops -# from megablocks.ops import ops -from megablocks.layers import common, dmlp_registry, moe, mpu -from megablocks.layers.arguments import Arguments - +# import megablocks.ops as ops +# # from megablocks.ops import ops +# from megablocks.layers import common, dmlp_registry, moe, mpu +# from megablocks.layers.arguments import Arguments + +from .. import ops +from . 
import common, dmlp_registry, moe, mpu +from .arguments import Arguments def promote_scalar(x): return x.view(1) if not len(x.size()) else x diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/glu.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/glu.py index cbe0c915c307e7f7cade3ea3ff679399635fcd81..69ef9ead5461a5b91ace107652fd8d1c3d300b1e 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/glu.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/glu.py @@ -4,11 +4,22 @@ import stk.ops import torch -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import Arguments -from megablocks.layers.mlp import ( +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import Arguments +# from megablocks.layers.mlp import ( +# SharedMLP, +# SparseMLP, +# create_dmoe_expert_weights, +# resolve_dtensor, +# ) + +from .. import grouped_gemm_util as gg +from . import common, mpu +from .activation_fn import act_fn +from .arguments import Arguments +from .mlp import ( SharedMLP, SparseMLP, create_dmoe_expert_weights, diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/memory_test.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/memory_test.py index 4acbd94f212ce906ace9de78dd3ffc6afa03f97e..74d1166931b712635131985b25a89f4ca23e576d 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/memory_test.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/memory_test.py @@ -6,7 +6,8 @@ import gc import torch import torch.distributed as dist -from megablocks.layers import arguments, dmoe +# from megablocks.layers import arguments, dmoe +from . import arguments, dmoe _TESTS = ((8, 2048, 4096, 4096, 32, 4),) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mlp.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mlp.py index 6e6f4d82441e3a9c185db5bfdf686d53790dde26..0b6747b1eec2047eb9480592ca550b6c9e868f5e 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mlp.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mlp.py @@ -9,11 +9,15 @@ import stk.ops import torch from packaging import version -from megablocks import grouped_gemm_util as gg -from megablocks.layers import common, gelu, mpu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, gelu, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + +from .. import grouped_gemm_util as gg +from . 
import common, gelu, mpu +from .activation_fn import act_fn +from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn class ScaleGradient(torch.autograd.Function): diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/moe.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/moe.py index 9ba5edb7fb1a65c276dc0ccea9e884362bc3e14e..d0a4aeaacc9c86fc70944e730c53f7a55644e05e 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/moe.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/moe.py @@ -6,10 +6,27 @@ import numpy as np import torch import torch.distributed as dist -import megablocks.ops as ops -from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -from megablocks.layers.all_to_all import all_to_all -from megablocks.layers.arguments import Arguments +# import megablocks.ops as ops +# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +# from megablocks.layers.all_to_all import all_to_all +# from megablocks.layers.arguments import Arguments + +from ..ops import ( + sort, + histogram, + inclusive_cumsum, + exclusive_cumsum, + binned_gather, + binned_scatter, + gather, + scatter, + repeat, + replicate, +) + +from . import common, mlp, mpu, router, sharedexpert_registry +from .arguments import Arguments +from .all_to_all import all_to_all _LOAD_BALANCING_LOSS = [] @@ -158,7 +175,8 @@ class ParallelMLP(torch.nn.Module): # prior? Could we place the `torch.max` operation to return # 32-bit expert indices? top_expert = top_expert.int() - output = ops.sort(top_expert, self.sort_end_bit) + # output = ops.sort(top_expert, self.sort_end_bit) + output = sort(top_expert, self.sort_end_bit) assert output is not None bin_ids, indices = output @@ -168,10 +186,12 @@ class ParallelMLP(torch.nn.Module): # TODO(tgale): Does the sorted data produce a more favorable # data distribution for histogram? Or is the op parallelism # worth more? - tokens_per_expert = ops.histogram(top_expert, self.num_experts) + # tokens_per_expert = ops.histogram(top_expert, self.num_experts) + tokens_per_expert = histogram(top_expert, self.num_experts) # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) + # bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = inclusive_cumsum(tokens_per_expert, 0) assert bins is not None bins = bins.view(1) if not len(bins.size()) else bins @@ -195,7 +215,8 @@ class ParallelMLP(torch.nn.Module): ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + output = binned_gather(x, indices, bins, expert_capacity, top_k) assert output is not None x = output @@ -204,7 +225,9 @@ class ParallelMLP(torch.nn.Module): x = self.mlp(x) # Un-route the data for the MoE output. - return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return binned_scatter(x, indices, expert_weights, bins, top_k) + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # x: [sl, bs, hs] @@ -264,7 +287,8 @@ class ParallelMLP(torch.nn.Module): # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. # Replicate the token counts so every device gets the counts. 
- repeated_tokens_per_expert = ops.repeat( + # repeated_tokens_per_expert = ops.repeat( + repeated_tokens_per_expert = repeat( tokens_per_expert, (mpu.hidden_sharding_degree(self.args),), ) @@ -285,7 +309,8 @@ class ParallelMLP(torch.nn.Module): # This view updates the shape of the tensor from [sl, bs, hs] to # [sl * bs, hs] prior to the permutation. x = x.view(-1, x.shape[-1]) - output = ops.gather(x, indices, bin_ids, bins, self.top_k) + # output = ops.gather(x, indices, bin_ids, bins, self.top_k) + output = gather(x, indices, bin_ids, bins, self.top_k) assert output is not None x = output @@ -317,7 +342,8 @@ class ParallelMLP(torch.nn.Module): # get all of the tokens assigned to them. # # TODO(tgale): Fuse this into the prior, local permutation. - x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) # Start the cross-device permutation asynchronously so we can # overlap communication with computation. @@ -336,7 +362,8 @@ class ParallelMLP(torch.nn.Module): # for expert computation we'll do one more local permutation. The # rest of this torch.no_grad() scope sets up the indices and bins # for this permutation. - replicate_bins = ops.inclusive_cumsum( + # replicate_bins = ops.inclusive_cumsum( + replicate_bins = inclusive_cumsum( parallel_tokens_per_expert.flatten(), 0, ) @@ -351,14 +378,16 @@ class ParallelMLP(torch.nn.Module): ), mpu.experts_per_rank(self.args), ) - parallel_top_expert = ops.replicate( + # parallel_top_expert = ops.replicate( + parallel_top_expert = replicate( parallel_top_expert.unsqueeze(dim=0), replicate_bins, tokens_received, ).flatten() # TODO(tgale): The sort_end_bit here can be reduced. - parallel_bin_ids, parallel_indices = ops.sort( + # parallel_bin_ids, parallel_indices = ops.sort( + parallel_bin_ids, parallel_indices = sort( parallel_top_expert, self.sort_end_bit, ) @@ -368,7 +397,8 @@ class ParallelMLP(torch.nn.Module): dim=0, dtype=torch.int, ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) # If expert_capacity is set to zero, set the number of tokens @@ -416,10 +446,12 @@ class ParallelMLP(torch.nn.Module): -1, self.args.hidden_size, ) - x = ops.sum(x.view(shape), dim=0) + # x = ops.sum(x.view(shape), dim=0) + x = x.view(shape).sum(dim=0) # Un-permute locally to setup for the next series of operations. 
- x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mpu.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mpu.py index b23213902f4567d7fdb0158cbcf5406c2b2aa601..434e143ab42bf3f83406d69e9dd1f72777716e22 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mpu.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mpu.py @@ -6,7 +6,8 @@ from typing import Optional import torch import torch.distributed as dist -from megablocks.layers.arguments import Arguments +# from megablocks.layers.arguments import Arguments +from .arguments import Arguments class MoeParam(torch.Tensor): diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/router.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/router.py index 2c9dcd9e433322482a80d5b95afccee5c12368f8..37cb2782348d62583376f1a183c7ede83601216d 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/router.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/router.py @@ -4,8 +4,10 @@ from typing import Any import torch -from megablocks.layers import common -from megablocks.layers.arguments import Arguments +# from megablocks.layers import common +# from megablocks.layers.arguments import Arguments +from . import common +from .arguments import Arguments _ROUTER_LOGITS = [] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/sharedexpert_registry.py index 0f62db39a2dda5642f6baa9d9b949f7c31cf6d35..5840862f88f370ace5fd49bd0612fc98d186cc49 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/sharedexpert_registry.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/sharedexpert_registry.py @@ -3,8 +3,10 @@ from typing import Union -from megablocks.layers import glu, mlp -from megablocks.layers.arguments import Arguments +# from megablocks.layers import glu, mlp +# from megablocks.layers.arguments import Arguments +from . 
import glu, mlp +from .arguments import Arguments _REGISTRY = { 'mlp': mlp.SharedMLP, diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/__init__.py index b9dc286a7b8c3d7dee98f318e5ebbf3aca7fc54c..b944080df810d0b0cfc571f3009b0098a651f9b7 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/__init__.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/__init__.py @@ -1,20 +1,20 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from megablocks.ops.binned_gather import binned_gather -from megablocks.ops.binned_scatter import binned_scatter -from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum -from megablocks.ops.gather import gather -from megablocks.ops.histogram import histogram -from megablocks.ops.padded_gather import padded_gather -from megablocks.ops.padded_scatter import padded_scatter -from megablocks.ops.repeat import repeat -from megablocks.ops.replicate import replicate -from megablocks.ops.round_up import round_up -from megablocks.ops.scatter import scatter -from megablocks.ops.sort import sort -from megablocks.ops.sum import sum -from megablocks.ops.topology import topology +from .binned_gather import binned_gather +from .binned_scatter import binned_scatter +from .cumsum import exclusive_cumsum, inclusive_cumsum +from .gather import gather +from .histogram import histogram +from .padded_gather import padded_gather +from .padded_scatter import padded_scatter +from .repeat import repeat +from .replicate import replicate +from .round_up import round_up +from .scatter import scatter +from .sort import sort +from .sum import sum +from .topology import topology __all__ = [ 'binned_gather', diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 47b95301dbb35b0c3154d1ff2878d20792ce5cb2..43d267dbe2570e4a1f59fd561398668cc2bc0920 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -4,8 +4,11 @@ import torch import torch.distributed as dist -from megablocks import benchmark_util -from megablocks.layers.all_to_all import all_to_all +# from megablocks import benchmark_util +# from megablocks.layers.all_to_all import all_to_all + +from .. import benchmark_util +from ..layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py index 89cce1b627137d7f987baca62fd6d82c5c04659a..8ae2ad8388b06db46a13f8fa46083619c44eefe2 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_gather kernel. 
diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py index f5ce0d6f5a890a58b89e6a501216a9822341323f..6d8654bdc718cd893f7899e9cdb6cd20544d189c 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for binned_scatter kernel. diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/cumsum.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/cumsum.py index 0d7c5137c900d66a959a76fa681436362a4df906..e2b7572391e20045d335cf7337246e8a9b9f57ef 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/cumsum.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/cumsum.py @@ -11,7 +11,7 @@ import torch # instructions for building the c++ operations. try: # import megablocks_ops as ops # type: ignore - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py index 41b09a1233e8a996ff1062579cd0810d095ad1e6..4edf4541dac52abab94151efb414acaa7711f8f6 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for gather kernel. diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram.py index 77e06a29163aa335dc61e8c50932d5f83322ae1d..7b3f058ec373cbba7555704fb5e4212c3cc75d9d 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram_benchmark.py index 9de8e65281c829ba63b5c60853dfa5bb0a333988..c57b7bf8228e01237236748147368b09ffdf8072 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. 
import ops _HISTOGRAM_TESTS = ( (16384, torch.int32, 2), diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py index bfa7b7c2f697f9454c5381e31ee78040be5b229e..9b2c7f062071758def741bf3af2392d3a4855f89 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -7,7 +7,7 @@ import stk import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops # Calling tensor.t() calls tensor.transpose(0, 1) which calls diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py index f272a7768dc6468e81bd2cd25294ca16a6826c08..0ffe1369d6adbce5c4a54155c636d1a4c022a41d 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_gather kernel. diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py index 9ff81dd9456a04258346266e9e85871bda56c65b..6685b0b83e7b1d92e7a772715991ead8ad94b153 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py @@ -5,7 +5,7 @@ from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for padded_scatter kernel. diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py index 81dde4e4b4c0f550c6027390e8de285fa4d842f7..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. import benchmark_util, ops _PADDED_SCATTER_BENCHMARK = ( # dMoE-Medium, 8-way EMP. diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/permute_benchmark.py index 837f07e2858d5c65c14e3f262617e7d2985ce901..6536eeeae402659a087e5c51ef9840627af56501 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/permute_benchmark.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -6,7 +6,7 @@ import unittest import torch from absl.testing import parameterized -from megablocks import benchmark_util, ops +from .. 
import benchmark_util, ops _PERMUTE_TESTS = ( (16384, 768, 2), diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/replicate.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/replicate.py index e570e22093339d1b71dc830edf656897c67f40e7..26daf0eede330603a4b8ea7167faf1411d07ca93 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/replicate.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/replicate.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py index a5aaafc402a651a47595c25dbf71e382903e5022..3f4bacf3422543e11abb795c279876147f8610a8 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py @@ -6,7 +6,7 @@ from typing import Any, Optional import torch from stk.backend.autocast import custom_bwd, custom_fwd -from megablocks.backend import kernels +from ..backend import kernels # Autograd wrapper for scatter kernel. diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort.py index a64a4553323d8de71a1d5f1821bac50249c68197..bda3bf64283e39533c2eae3627e76bb2d0262c9f 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort_benchmark.py index f28e3f2f22d28a1abd856f3b4aa6e33e7928b8ec..a92ff957d4c552c6e61d9279a7989795472af7b7 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort_benchmark.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -7,7 +7,7 @@ import numpy as np import torch from absl.testing import parameterized -from megablocks import ops +from .. import ops _SORT_TESTS = ( (16384, torch.int32, None), diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/topology.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/topology.py index b4a693a82a36c052e51afb4a02077e7849a4abf1..76a50d3164db20534b099dcb4d8487a7aef25d15 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/topology.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/topology.py @@ -10,7 +10,7 @@ import torch # Wrap this in a try-block with better error message and # instructions for building the c++ operations. try: - from megablocks._ops import ops # type: ignore + from .._ops import ops # type: ignore except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e