kernel
drbh committed
Commit 6756875 · Parent: 3224250

fix: bump builds

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. build.toml +3 -1
  2. build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py +9 -5
  3. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_359242d.abi3.so → _megablocks_a585153_dirty.abi3.so} +2 -2
  4. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  5. build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py +2 -0
  6. build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py +32 -0
  7. build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py +33 -0
  8. build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py +8 -3
  9. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py +1 -1
  10. build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py +9 -5
  11. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_359242d.abi3.so → _megablocks_a585153_dirty.abi3.so} +2 -2
  12. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +3 -3
  13. build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/__init__.py +2 -0
  14. build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py +32 -0
  15. build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/ops.py +33 -0
  16. build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm_util.py +8 -3
  17. build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/__init__.py +1 -1
  18. build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py +9 -5
  19. build/torch26-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_359242d.abi3.so → _megablocks_a585153_dirty.abi3.so} +2 -2
  20. build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py +3 -3
  21. build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py +2 -0
  22. build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py +32 -0
  23. build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py +33 -0
  24. build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py +8 -3
  25. build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py +1 -1
  26. build/torch26-cxx98-cu118-x86_64-linux/megablocks/__init__.py +9 -5
  27. build/torch26-cxx98-cu118-x86_64-linux/megablocks/{_megablocks_359242d.abi3.so → _megablocks_a585153_dirty.abi3.so} +2 -2
  28. build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  29. build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py +2 -0
  30. build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py +32 -0
  31. build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py +33 -0
  32. build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm_util.py +8 -3
  33. build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/__init__.py +1 -1
  34. build/torch26-cxx98-cu124-x86_64-linux/megablocks/__init__.py +9 -5
  35. build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so +0 -3
  36. build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so +3 -0
  37. build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py +3 -3
  38. build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/__init__.py +2 -0
  39. build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py +32 -0
  40. build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/ops.py +33 -0
  41. build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm_util.py +8 -3
  42. build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/__init__.py +1 -1
  43. build/torch26-cxx98-cu126-x86_64-linux/megablocks/__init__.py +9 -5
  44. build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so +0 -3
  45. build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so +3 -0
  46. build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py +3 -3
  47. build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py +2 -0
  48. build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py +32 -0
  49. build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py +33 -0
  50. build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm_util.py +8 -3
build.toml CHANGED
@@ -21,7 +21,9 @@ cuda-capabilities = [
   "9.0",
   "10.0",
   "10.1",
-  "12.0",
+  "11.8",
+  "12.0"
+  # "12.4"
 ]
 depends = ["torch", "cutlass_3_8"]
 src = [
build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py CHANGED
@@ -5,11 +5,15 @@ import torch
 
 from ._ops import ops
 
-from megablocks.layers.arguments import Arguments
-from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE
-from megablocks.layers.glu import SparseGLU
-from megablocks.layers.mlp import MLP, SparseMLP
-from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from .layers.arguments import Arguments
+from .layers.dmoe import ParallelDroplessMLP, dMoE
+from .layers.glu import SparseGLU
+from .layers.mlp import MLP, SparseMLP
+from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss
 
 # This section contains the direct kernel exports (not inlcuded in the original code)
 def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
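For orientation, a minimal usage sketch of the re-exported surface above; it is not part of this diff. It assumes the compiled _megablocks_a585153_dirty extension is importable and a CUDA device is available; the int32 dtype is also an assumption about what the underlying op accepts.

import torch
import megablocks  # pulls in ._ops plus the relative layer/grouped_gemm imports above

# Direct kernel export declared in this __init__.py: exclusive cumulative sum.
x = torch.tensor([1, 2, 3, 4, 5], device="cuda", dtype=torch.int32)  # dtype assumed
out = torch.empty_like(x)
megablocks.exclusive_cumsum(x, 0, out)
print(out)  # expected [0, 1, 3, 6, 10] for an exclusive cumsum along dim 0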
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_359242d.abi3.so → _megablocks_a585153_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:36a918d61308a0acbc880516a42fcc2c4dc393f25655326e52128465c9501709
-size 10456376
+oid sha256:44462d45f75616c369c2421fe41d53cd1d1dc365f1d2545d870e2db999e67e38
+size 10517608
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_359242d
-ops = torch.ops._megablocks_359242d
+from . import _megablocks_a585153_dirty
+ops = torch.ops._megablocks_a585153_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_359242d::{op_name}"
+    return f"_megablocks_a585153_dirty::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py ADDED
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py ADDED
@@ -0,0 +1,32 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+from megablocks._ops import ops as backend  # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+    assert not (trans_a and trans_b)
+    assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+    assert a.ndim == 2, "Expected 2d tensor for 'a'"
+    assert b.ndim == (2 if trans_a else 3)
+
+    shape = (
+        (batch_sizes.shape[0], a.shape[1], b.shape[1])
+        if trans_a else
+        (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+    )
+    return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+    if c is None:
+        c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+    return c
build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py ADDED
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, a, b, batch_sizes, trans_b):
+        ctx.save_for_backward(a, b, batch_sizes)
+        ctx.trans_b = trans_b
+        return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+    @staticmethod
+    def backward(ctx, grad):
+        grad = grad.contiguous()
+        a, b, batch_sizes = ctx.saved_tensors
+        trans_b = ctx.trans_b
+
+        agrad = None
+        if ctx.needs_input_grad[0]:
+            agrad = backend.gmm(
+                grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+        bgrad = None
+        if ctx.needs_input_grad[1]:
+            lhs, rhs = (grad, a) if trans_b else (a, grad)
+            bgrad = backend.gmm(
+                lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+        return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+    return GroupedGemm.apply(a, b, batch_sizes, trans_b)
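The two files above vendor grouped_gemm's Python surface: backend.py allocates the output and calls the compiled op, while ops.py wraps it in an autograd Function. Below is a rough calling sketch, not taken from this diff, with shapes following _allocate_output; the CUDA device, bfloat16 dtype, and int64 batch_sizes are assumptions about the compiled kernel rather than facts stated in the commit.

import torch
from megablocks import gg_ops  # alias added to megablocks/__init__.py in this commit

num_experts, tokens_per_expert, hidden, ffn = 4, 8, 16, 32
# 'a' stacks every expert's tokens; 'b' holds one (hidden, ffn) weight per expert.
a = torch.randn(num_experts * tokens_per_expert, hidden,
                device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(num_experts, hidden, ffn,
                device="cuda", dtype=torch.bfloat16, requires_grad=True)
# One row count per expert; entries must sum to a.shape[0] (dtype/device assumed).
batch_sizes = torch.full((num_experts,), tokens_per_expert, dtype=torch.int64)

out = gg_ops.gmm(a, b, batch_sizes)  # (tokens, ffn) via GroupedGemm.forward
out.sum().backward()                 # GroupedGemm.backward fills a.grad and b.grad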
build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py CHANGED
@@ -4,7 +4,8 @@ import warnings
 
 _grouped_gemm_is_available: bool = False
 try:
-    import grouped_gemm
+    # import grouped_gemm
+    pass
     _grouped_gemm_is_available = True
 except ImportError as error:
     warnings.warn('Grouped GEMM not available.')
@@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available():
     assert _grouped_gemm_is_available, msg
 
 
-backend = grouped_gemm.backend if grouped_gemm_is_available() else None
-ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend as ops
+from .grouped_gemm import ops as backend
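One detail worth noting in the rewritten utility above: the new aliases bind `.grouped_gemm.backend` to the name `ops` and `.grouped_gemm.ops` to the name `backend`, exactly as written in the diff. A small sketch (not from this commit) of how a caller can gate on availability and reach the autograd wrapper through those names:

from megablocks import grouped_gemm_util as gg

# Raises an AssertionError if the vendored grouped GEMM support could not be imported.
gg.assert_grouped_gemm_is_available()

# Per the aliases above: gg.backend is grouped_gemm/ops.py (autograd wrapper),
# while gg.ops is grouped_gemm/backend.py (allocation + compiled-op call).
autograd_gmm = gg.backend.gmm
raw_gmm = gg.ops.gmm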
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py CHANGED
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # from megablocks.layers.dmoe import dMoE
-from megablocks.layers.moe import MoE
+from .moe import MoE
 
 __all__ = [
     'MoE',
build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py CHANGED
@@ -5,11 +5,15 @@ import torch
 
 from ._ops import ops
 
-from megablocks.layers.arguments import Arguments
-from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE
-from megablocks.layers.glu import SparseGLU
-from megablocks.layers.mlp import MLP, SparseMLP
-from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from .layers.arguments import Arguments
+from .layers.dmoe import ParallelDroplessMLP, dMoE
+from .layers.glu import SparseGLU
+from .layers.mlp import MLP, SparseMLP
+from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss
 
 # This section contains the direct kernel exports (not inlcuded in the original code)
 def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_359242d.abi3.so → _megablocks_a585153_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:12cf047c4bcb5f368f490ada3dafcecd1242a443ff5ded94c09d62548922098c
-size 11795992
+oid sha256:e734576700345e035790357ea19730e84e90c176747076ce845995bc3a0e0d50
+size 11869424
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_359242d
-ops = torch.ops._megablocks_359242d
+from . import _megablocks_a585153_dirty
+ops = torch.ops._megablocks_a585153_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_359242d::{op_name}"
+    return f"_megablocks_a585153_dirty::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/__init__.py ADDED
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py ADDED
@@ -0,0 +1,32 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+from megablocks._ops import ops as backend  # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+    assert not (trans_a and trans_b)
+    assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+    assert a.ndim == 2, "Expected 2d tensor for 'a'"
+    assert b.ndim == (2 if trans_a else 3)
+
+    shape = (
+        (batch_sizes.shape[0], a.shape[1], b.shape[1])
+        if trans_a else
+        (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+    )
+    return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+    if c is None:
+        c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+    return c
build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/ops.py ADDED
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, a, b, batch_sizes, trans_b):
+        ctx.save_for_backward(a, b, batch_sizes)
+        ctx.trans_b = trans_b
+        return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+    @staticmethod
+    def backward(ctx, grad):
+        grad = grad.contiguous()
+        a, b, batch_sizes = ctx.saved_tensors
+        trans_b = ctx.trans_b
+
+        agrad = None
+        if ctx.needs_input_grad[0]:
+            agrad = backend.gmm(
+                grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+        bgrad = None
+        if ctx.needs_input_grad[1]:
+            lhs, rhs = (grad, a) if trans_b else (a, grad)
+            bgrad = backend.gmm(
+                lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+        return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+    return GroupedGemm.apply(a, b, batch_sizes, trans_b)
build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm_util.py CHANGED
@@ -4,7 +4,8 @@ import warnings
 
 _grouped_gemm_is_available: bool = False
 try:
-    import grouped_gemm
+    # import grouped_gemm
+    pass
     _grouped_gemm_is_available = True
 except ImportError as error:
     warnings.warn('Grouped GEMM not available.')
@@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available():
     assert _grouped_gemm_is_available, msg
 
 
-backend = grouped_gemm.backend if grouped_gemm_is_available() else None
-ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend as ops
+from .grouped_gemm import ops as backend
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/__init__.py CHANGED
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # from megablocks.layers.dmoe import dMoE
-from megablocks.layers.moe import MoE
+from .moe import MoE
 
 __all__ = [
     'MoE',
build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py CHANGED
@@ -5,11 +5,15 @@ import torch
 
 from ._ops import ops
 
-from megablocks.layers.arguments import Arguments
-from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE
-from megablocks.layers.glu import SparseGLU
-from megablocks.layers.mlp import MLP, SparseMLP
-from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from .layers.arguments import Arguments
+from .layers.dmoe import ParallelDroplessMLP, dMoE
+from .layers.glu import SparseGLU
+from .layers.mlp import MLP, SparseMLP
+from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss
 
 # This section contains the direct kernel exports (not inlcuded in the original code)
 def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
build/torch26-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_359242d.abi3.so → _megablocks_a585153_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:38fcc4266dd94ee3f307bebd1264a1faaf91c73dc680adb72cd268748130b10f
-size 11835888
+oid sha256:8507dd1e6fc8f4df45af233d506ef96b962cacecf9e2d0694247547b0dd7dde0
+size 11931080
build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_359242d
-ops = torch.ops._megablocks_359242d
+from . import _megablocks_a585153_dirty
+ops = torch.ops._megablocks_a585153_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_359242d::{op_name}"
+    return f"_megablocks_a585153_dirty::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py ADDED
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py ADDED
@@ -0,0 +1,32 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+from megablocks._ops import ops as backend  # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+    assert not (trans_a and trans_b)
+    assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+    assert a.ndim == 2, "Expected 2d tensor for 'a'"
+    assert b.ndim == (2 if trans_a else 3)
+
+    shape = (
+        (batch_sizes.shape[0], a.shape[1], b.shape[1])
+        if trans_a else
+        (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+    )
+    return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+    if c is None:
+        c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+    return c
build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py ADDED
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, a, b, batch_sizes, trans_b):
+        ctx.save_for_backward(a, b, batch_sizes)
+        ctx.trans_b = trans_b
+        return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+    @staticmethod
+    def backward(ctx, grad):
+        grad = grad.contiguous()
+        a, b, batch_sizes = ctx.saved_tensors
+        trans_b = ctx.trans_b
+
+        agrad = None
+        if ctx.needs_input_grad[0]:
+            agrad = backend.gmm(
+                grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+        bgrad = None
+        if ctx.needs_input_grad[1]:
+            lhs, rhs = (grad, a) if trans_b else (a, grad)
+            bgrad = backend.gmm(
+                lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+        return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+    return GroupedGemm.apply(a, b, batch_sizes, trans_b)
build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py CHANGED
@@ -4,7 +4,8 @@ import warnings
 
 _grouped_gemm_is_available: bool = False
 try:
-    import grouped_gemm
+    # import grouped_gemm
+    pass
     _grouped_gemm_is_available = True
 except ImportError as error:
     warnings.warn('Grouped GEMM not available.')
@@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available():
     assert _grouped_gemm_is_available, msg
 
 
-backend = grouped_gemm.backend if grouped_gemm_is_available() else None
-ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend as ops
+from .grouped_gemm import ops as backend
build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py CHANGED
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # from megablocks.layers.dmoe import dMoE
-from megablocks.layers.moe import MoE
+from .moe import MoE
 
 __all__ = [
     'MoE',
build/torch26-cxx98-cu118-x86_64-linux/megablocks/__init__.py CHANGED
@@ -5,11 +5,15 @@ import torch
 
 from ._ops import ops
 
-from megablocks.layers.arguments import Arguments
-from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE
-from megablocks.layers.glu import SparseGLU
-from megablocks.layers.mlp import MLP, SparseMLP
-from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from .layers.arguments import Arguments
+from .layers.dmoe import ParallelDroplessMLP, dMoE
+from .layers.glu import SparseGLU
+from .layers.mlp import MLP, SparseMLP
+from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss
 
 # This section contains the direct kernel exports (not inlcuded in the original code)
 def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
build/torch26-cxx98-cu118-x86_64-linux/megablocks/{_megablocks_359242d.abi3.so → _megablocks_a585153_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c9661a5704bd53f129788d2b9241f44cf00e1447dab8127d5a07675b1d6ca2ba
-size 10444224
+oid sha256:6dc0dcea20fc1350689addf7cb9927f7bb709f68ed89d4c711b0f7db579a463b
+size 10510072
build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_359242d
-ops = torch.ops._megablocks_359242d
+from . import _megablocks_a585153_dirty
+ops = torch.ops._megablocks_a585153_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_359242d::{op_name}"
+    return f"_megablocks_a585153_dirty::{op_name}"
build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py ADDED
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py ADDED
@@ -0,0 +1,32 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+from megablocks._ops import ops as backend  # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+    assert not (trans_a and trans_b)
+    assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+    assert a.ndim == 2, "Expected 2d tensor for 'a'"
+    assert b.ndim == (2 if trans_a else 3)
+
+    shape = (
+        (batch_sizes.shape[0], a.shape[1], b.shape[1])
+        if trans_a else
+        (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+    )
+    return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+    if c is None:
+        c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+    return c
build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py ADDED
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, a, b, batch_sizes, trans_b):
+        ctx.save_for_backward(a, b, batch_sizes)
+        ctx.trans_b = trans_b
+        return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+    @staticmethod
+    def backward(ctx, grad):
+        grad = grad.contiguous()
+        a, b, batch_sizes = ctx.saved_tensors
+        trans_b = ctx.trans_b
+
+        agrad = None
+        if ctx.needs_input_grad[0]:
+            agrad = backend.gmm(
+                grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+        bgrad = None
+        if ctx.needs_input_grad[1]:
+            lhs, rhs = (grad, a) if trans_b else (a, grad)
+            bgrad = backend.gmm(
+                lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+        return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+    return GroupedGemm.apply(a, b, batch_sizes, trans_b)
build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm_util.py CHANGED
@@ -4,7 +4,8 @@ import warnings
 
 _grouped_gemm_is_available: bool = False
 try:
-    import grouped_gemm
+    # import grouped_gemm
+    pass
     _grouped_gemm_is_available = True
 except ImportError as error:
     warnings.warn('Grouped GEMM not available.')
@@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available():
     assert _grouped_gemm_is_available, msg
 
 
-backend = grouped_gemm.backend if grouped_gemm_is_available() else None
-ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend as ops
+from .grouped_gemm import ops as backend
build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/__init__.py CHANGED
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # from megablocks.layers.dmoe import dMoE
-from megablocks.layers.moe import MoE
+from .moe import MoE
 
 __all__ = [
     'MoE',
build/torch26-cxx98-cu124-x86_64-linux/megablocks/__init__.py CHANGED
@@ -5,11 +5,15 @@ import torch
 
 from ._ops import ops
 
-from megablocks.layers.arguments import Arguments
-from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE
-from megablocks.layers.glu import SparseGLU
-from megablocks.layers.mlp import MLP, SparseMLP
-from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from .layers.arguments import Arguments
+from .layers.dmoe import ParallelDroplessMLP, dMoE
+from .layers.glu import SparseGLU
+from .layers.mlp import MLP, SparseMLP
+from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss
 
 # This section contains the direct kernel exports (not inlcuded in the original code)
 def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:ec45bdb77d89916e5c58e17ddb76148799e3cac07fa6f4e93bf9140d9b2039bb
-size 11788400
build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb0c2b91105c2f32f590aaa9d90ae2d6b36834bae9b35fb55c4b4fc90da56bc3
+size 11857952
build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_359242d
-ops = torch.ops._megablocks_359242d
+from . import _megablocks_a585153_dirty
+ops = torch.ops._megablocks_a585153_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_359242d::{op_name}"
+    return f"_megablocks_a585153_dirty::{op_name}"
build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/__init__.py ADDED
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py ADDED
@@ -0,0 +1,32 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+from megablocks._ops import ops as backend  # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+    assert not (trans_a and trans_b)
+    assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+    assert a.ndim == 2, "Expected 2d tensor for 'a'"
+    assert b.ndim == (2 if trans_a else 3)
+
+    shape = (
+        (batch_sizes.shape[0], a.shape[1], b.shape[1])
+        if trans_a else
+        (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+    )
+    return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+    if c is None:
+        c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+    return c
build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/ops.py ADDED
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, a, b, batch_sizes, trans_b):
+        ctx.save_for_backward(a, b, batch_sizes)
+        ctx.trans_b = trans_b
+        return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+    @staticmethod
+    def backward(ctx, grad):
+        grad = grad.contiguous()
+        a, b, batch_sizes = ctx.saved_tensors
+        trans_b = ctx.trans_b
+
+        agrad = None
+        if ctx.needs_input_grad[0]:
+            agrad = backend.gmm(
+                grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+        bgrad = None
+        if ctx.needs_input_grad[1]:
+            lhs, rhs = (grad, a) if trans_b else (a, grad)
+            bgrad = backend.gmm(
+                lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+        return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+    return GroupedGemm.apply(a, b, batch_sizes, trans_b)
build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm_util.py CHANGED
@@ -4,7 +4,8 @@ import warnings
 
 _grouped_gemm_is_available: bool = False
 try:
-    import grouped_gemm
+    # import grouped_gemm
+    pass
     _grouped_gemm_is_available = True
 except ImportError as error:
     warnings.warn('Grouped GEMM not available.')
@@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available():
     assert _grouped_gemm_is_available, msg
 
 
-backend = grouped_gemm.backend if grouped_gemm_is_available() else None
-ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend as ops
+from .grouped_gemm import ops as backend
build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/__init__.py CHANGED
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # from megablocks.layers.dmoe import dMoE
-from megablocks.layers.moe import MoE
+from .moe import MoE
 
 __all__ = [
     'MoE',
build/torch26-cxx98-cu126-x86_64-linux/megablocks/__init__.py CHANGED
@@ -5,11 +5,15 @@ import torch
 
 from ._ops import ops
 
-from megablocks.layers.arguments import Arguments
-from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE
-from megablocks.layers.glu import SparseGLU
-from megablocks.layers.mlp import MLP, SparseMLP
-from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from .layers.arguments import Arguments
+from .layers.dmoe import ParallelDroplessMLP, dMoE
+from .layers.glu import SparseGLU
+from .layers.mlp import MLP, SparseMLP
+from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss
 
 # This section contains the direct kernel exports (not inlcuded in the original code)
 def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:99859c18a9a4e7ec11e7fa805e2225644f8e5f51e2546b4525ddf8e939f48874
-size 11832392
build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd817eed5069e786933346cb2bb5ab6f586878ae80647191932336dec3295c96
+size 11923704
build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_359242d
-ops = torch.ops._megablocks_359242d
+from . import _megablocks_a585153_dirty
+ops = torch.ops._megablocks_a585153_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_359242d::{op_name}"
+    return f"_megablocks_a585153_dirty::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py ADDED
@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py ADDED
@@ -0,0 +1,32 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+from megablocks._ops import ops as backend  # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+    assert not (trans_a and trans_b)
+    assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+    assert a.ndim == 2, "Expected 2d tensor for 'a'"
+    assert b.ndim == (2 if trans_a else 3)
+
+    shape = (
+        (batch_sizes.shape[0], a.shape[1], b.shape[1])
+        if trans_a else
+        (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+    )
+    return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+    if c is None:
+        c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+    return c
build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py ADDED
@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, a, b, batch_sizes, trans_b):
+        ctx.save_for_backward(a, b, batch_sizes)
+        ctx.trans_b = trans_b
+        return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+    @staticmethod
+    def backward(ctx, grad):
+        grad = grad.contiguous()
+        a, b, batch_sizes = ctx.saved_tensors
+        trans_b = ctx.trans_b
+
+        agrad = None
+        if ctx.needs_input_grad[0]:
+            agrad = backend.gmm(
+                grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+        bgrad = None
+        if ctx.needs_input_grad[1]:
+            lhs, rhs = (grad, a) if trans_b else (a, grad)
+            bgrad = backend.gmm(
+                lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+        return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+    return GroupedGemm.apply(a, b, batch_sizes, trans_b)
build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm_util.py CHANGED
@@ -4,7 +4,8 @@ import warnings
 
 _grouped_gemm_is_available: bool = False
 try:
-    import grouped_gemm
+    # import grouped_gemm
+    pass
     _grouped_gemm_is_available = True
 except ImportError as error:
     warnings.warn('Grouped GEMM not available.')
@@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available():
     assert _grouped_gemm_is_available, msg
 
 
-backend = grouped_gemm.backend if grouped_gemm_is_available() else None
-ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend as ops
+from .grouped_gemm import ops as backend