diff --git a/build.toml b/build.toml index 55de98f47d899f2f4fa91fc62fc47456d4f79e60..5c42dbdc22f1deee8e1ffc05a538ea4021dd38c4 100644 --- a/build.toml +++ b/build.toml @@ -21,7 +21,9 @@ cuda-capabilities = [ "9.0", "10.0", "10.1", - "12.0", + "11.8", + "12.0" + # "12.4" ] depends = ["torch", "cutlass_3_8"] src = [ diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py index e9d8a2495c4baab89e5f90d84fa92f406a29ea10..af8d40aa52c706d2aae368cc11ade68cf13f6a47 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py @@ -5,11 +5,15 @@ import torch from ._ops import ops -from megablocks.layers.arguments import Arguments -from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE -from megablocks.layers.glu import SparseGLU -from megablocks.layers.mlp import MLP, SparseMLP -from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from .grouped_gemm import backend as gg_backend +from .grouped_gemm import ops as gg_ops + + +from .layers.arguments import Arguments +from .layers.dmoe import ParallelDroplessMLP, dMoE +from .layers.glu import SparseGLU +from .layers.mlp import MLP, SparseMLP +from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so deleted file mode 100755 index 50fe814b70e1ecd024b8ad5bda99feb8bab489aa..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:36a918d61308a0acbc880516a42fcc2c4dc393f25655326e52128465c9501709 -size 10456376 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..81c1d31b49a30085121ea47c012801b4f6fe3a00 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44462d45f75616c369c2421fe41d53cd1d1dc365f1d2545d870e2db999e67e38 +size 10517608 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py index fcd42432c910375fb1713b3bf29145e48b556498..fb18e304d3a97b1801f4f28ccd9b74403f400f6f 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_359242d -ops = torch.ops._megablocks_359242d +from . import _megablocks_a585153_dirty +ops = torch.ops._megablocks_a585153_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_megablocks_359242d::{op_name}" \ No newline at end of file + return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py @@ -0,0 +1,2 @@ +from . import ops +from . import backend diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..7c692f95852950a7b97aaaa2aa3d15e4297b6192 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -0,0 +1,32 @@ +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# # TODO(tgale): Wrap this in a try-block with better +# # error message and instructions for building the +# # c++ operations. +# import grouped_gemm_backend as backend + +# We import the backend operations from the megablocks package as +# grouped_gemm is vendored in megablocks in this repository. +# from ... import _ops as backend +from megablocks._ops import ops as backend # type: ignore + +def _allocate_output(a, b, batch_sizes, trans_a, trans_b): + assert not (trans_a and trans_b) + assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" + assert a.ndim == 2, "Expected 2d tensor for 'a'" + assert b.ndim == (2 if trans_a else 3) + + shape = ( + (batch_sizes.shape[0], a.shape[1], b.shape[1]) + if trans_a else + (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) + ) + return torch.empty(*shape, device=a.device, dtype=a.dtype) + +def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): + if c is None: + c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) + backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) + return c diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py @@ -0,0 +1,33 @@ +from . 
import backend +import torch + + +class GroupedGemm(torch.autograd.Function): + + @staticmethod + def forward(ctx, a, b, batch_sizes, trans_b): + ctx.save_for_backward(a, b, batch_sizes) + ctx.trans_b = trans_b + return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) + + @staticmethod + def backward(ctx, grad): + grad = grad.contiguous() + a, b, batch_sizes = ctx.saved_tensors + trans_b = ctx.trans_b + + agrad = None + if ctx.needs_input_grad[0]: + agrad = backend.gmm( + grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) + + bgrad = None + if ctx.needs_input_grad[1]: + lhs, rhs = (grad, a) if trans_b else (a, grad) + bgrad = backend.gmm( + lhs, rhs, batch_sizes, trans_a=True, trans_b=False) + return agrad, bgrad, None, None + + +def gmm(a, b, batch_sizes, trans_b=False): + return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py index 6d3f977f360fd0ad5800c3b5da9ce57be794b9b8..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py @@ -4,7 +4,8 @@ import warnings _grouped_gemm_is_available: bool = False try: - import grouped_gemm + # import grouped_gemm + pass _grouped_gemm_is_available = True except ImportError as error: warnings.warn('Grouped GEMM not available.') @@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available(): assert _grouped_gemm_is_available, msg -backend = grouped_gemm.backend if grouped_gemm_is_available() else None -ops = grouped_gemm.ops if grouped_gemm_is_available() else None +# backend = grouped_gemm.backend if grouped_gemm_is_available() else None +# ops = grouped_gemm.ops if grouped_gemm_is_available() else None + + +from .grouped_gemm import backend as ops +from .grouped_gemm import ops as backend diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py index 849b023b1a0f765fb1e26addf4670dc6db785a52..a720e7a2cc4e44636f6e433a2750e945dc38e8b2 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 # from megablocks.layers.dmoe import dMoE -from megablocks.layers.moe import MoE +from .moe import MoE __all__ = [ 'MoE', diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py index e9d8a2495c4baab89e5f90d84fa92f406a29ea10..af8d40aa52c706d2aae368cc11ade68cf13f6a47 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py @@ -5,11 +5,15 @@ import torch from ._ops import ops -from megablocks.layers.arguments import Arguments -from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE -from megablocks.layers.glu import SparseGLU -from megablocks.layers.mlp import MLP, SparseMLP -from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from .grouped_gemm import backend as gg_backend +from .grouped_gemm import ops as gg_ops + + +from .layers.arguments import Arguments +from .layers.dmoe import ParallelDroplessMLP, dMoE +from .layers.glu import SparseGLU +from .layers.mlp import MLP, SparseMLP +from .layers.moe import MoE, 
ParallelMLP, get_load_balancing_loss # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so deleted file mode 100755 index 519759fd66838a696df3771c278cb787e9c11dfe..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:12cf047c4bcb5f368f490ada3dafcecd1242a443ff5ded94c09d62548922098c -size 11795992 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..d1d1046bc8625e4dd1bd60290bf9b8f659f3247b --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e734576700345e035790357ea19730e84e90c176747076ce845995bc3a0e0d50 +size 11869424 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py index fcd42432c910375fb1713b3bf29145e48b556498..fb18e304d3a97b1801f4f28ccd9b74403f400f6f 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_359242d -ops = torch.ops._megablocks_359242d +from . import _megablocks_a585153_dirty +ops = torch.ops._megablocks_a585153_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_359242d::{op_name}" \ No newline at end of file + return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/__init__.py @@ -0,0 +1,2 @@ +from . import ops +from . import backend diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..7c692f95852950a7b97aaaa2aa3d15e4297b6192 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -0,0 +1,32 @@ +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# # TODO(tgale): Wrap this in a try-block with better +# # error message and instructions for building the +# # c++ operations. +# import grouped_gemm_backend as backend + +# We import the backend operations from the megablocks package as +# grouped_gemm is vendored in megablocks in this repository. +# from ... 
import _ops as backend +from megablocks._ops import ops as backend # type: ignore + +def _allocate_output(a, b, batch_sizes, trans_a, trans_b): + assert not (trans_a and trans_b) + assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" + assert a.ndim == 2, "Expected 2d tensor for 'a'" + assert b.ndim == (2 if trans_a else 3) + + shape = ( + (batch_sizes.shape[0], a.shape[1], b.shape[1]) + if trans_a else + (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) + ) + return torch.empty(*shape, device=a.device, dtype=a.dtype) + +def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): + if c is None: + c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) + backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) + return c diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/ops.py @@ -0,0 +1,33 @@ +from . import backend +import torch + + +class GroupedGemm(torch.autograd.Function): + + @staticmethod + def forward(ctx, a, b, batch_sizes, trans_b): + ctx.save_for_backward(a, b, batch_sizes) + ctx.trans_b = trans_b + return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) + + @staticmethod + def backward(ctx, grad): + grad = grad.contiguous() + a, b, batch_sizes = ctx.saved_tensors + trans_b = ctx.trans_b + + agrad = None + if ctx.needs_input_grad[0]: + agrad = backend.gmm( + grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) + + bgrad = None + if ctx.needs_input_grad[1]: + lhs, rhs = (grad, a) if trans_b else (a, grad) + bgrad = backend.gmm( + lhs, rhs, batch_sizes, trans_a=True, trans_b=False) + return agrad, bgrad, None, None + + +def gmm(a, b, batch_sizes, trans_b=False): + return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm_util.py index 6d3f977f360fd0ad5800c3b5da9ce57be794b9b8..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm_util.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm_util.py @@ -4,7 +4,8 @@ import warnings _grouped_gemm_is_available: bool = False try: - import grouped_gemm + # import grouped_gemm + pass _grouped_gemm_is_available = True except ImportError as error: warnings.warn('Grouped GEMM not available.') @@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available(): assert _grouped_gemm_is_available, msg -backend = grouped_gemm.backend if grouped_gemm_is_available() else None -ops = grouped_gemm.ops if grouped_gemm_is_available() else None +# backend = grouped_gemm.backend if grouped_gemm_is_available() else None +# ops = grouped_gemm.ops if grouped_gemm_is_available() else None + + +from .grouped_gemm import backend as ops +from .grouped_gemm import ops as backend diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/__init__.py index 849b023b1a0f765fb1e26addf4670dc6db785a52..a720e7a2cc4e44636f6e433a2750e945dc38e8b2 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/__init__.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/__init__.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 
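The `grouped_gemm/backend.py` and `grouped_gemm/ops.py` files added for each build variant above are the grouped-GEMM entry points the MoE layers call into. A minimal usage sketch, assuming a CUDA device and that the compiled `_megablocks_a585153_dirty` extension imports cleanly; the expert count, shapes, and dtype below are illustrative, and any dtype or device constraints of the compiled kernel are not visible in this diff:

```python
# Illustrative only: exercises the vendored grouped-GEMM wrapper added in
# this diff. Assumes a CUDA device and a loadable _megablocks extension;
# the shapes and dtype are made up for the example.
import torch
from megablocks.grouped_gemm import ops as gg_ops

num_experts, k, n = 4, 128, 256
# One entry per expert; 1-D and summing to the number of rows in `a`
# (matching the asserts in _allocate_output).
batch_sizes = torch.tensor([3, 1, 5, 7], dtype=torch.int64)
tokens = int(batch_sizes.sum())

a = torch.randn(tokens, k, device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(num_experts, k, n, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out = gg_ops.gmm(a, b, batch_sizes)  # (tokens, n): expert e multiplies its slice of `a` by b[e]
out.sum().backward()                 # GroupedGemm.backward populates a.grad and b.grad
```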
# from megablocks.layers.dmoe import dMoE -from megablocks.layers.moe import MoE +from .moe import MoE __all__ = [ 'MoE', diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py index e9d8a2495c4baab89e5f90d84fa92f406a29ea10..af8d40aa52c706d2aae368cc11ade68cf13f6a47 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py @@ -5,11 +5,15 @@ import torch from ._ops import ops -from megablocks.layers.arguments import Arguments -from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE -from megablocks.layers.glu import SparseGLU -from megablocks.layers.mlp import MLP, SparseMLP -from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from .grouped_gemm import backend as gg_backend +from .grouped_gemm import ops as gg_ops + + +from .layers.arguments import Arguments +from .layers.dmoe import ParallelDroplessMLP, dMoE +from .layers.glu import SparseGLU +from .layers.mlp import MLP, SparseMLP +from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so deleted file mode 100755 index 73be500db9e30d7b8ad95919d224ae58090f5a70..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:38fcc4266dd94ee3f307bebd1264a1faaf91c73dc680adb72cd268748130b10f -size 11835888 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..dbff38a1256f476e63a04bf2d924acda689595c4 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8507dd1e6fc8f4df45af233d506ef96b962cacecf9e2d0694247547b0dd7dde0 +size 11931080 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py index fcd42432c910375fb1713b3bf29145e48b556498..fb18e304d3a97b1801f4f28ccd9b74403f400f6f 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_359242d -ops = torch.ops._megablocks_359242d +from . import _megablocks_a585153_dirty +ops = torch.ops._megablocks_a585153_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_megablocks_359242d::{op_name}" \ No newline at end of file + return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py @@ -0,0 +1,2 @@ +from . import ops +from . import backend diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..7c692f95852950a7b97aaaa2aa3d15e4297b6192 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -0,0 +1,32 @@ +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# # TODO(tgale): Wrap this in a try-block with better +# # error message and instructions for building the +# # c++ operations. +# import grouped_gemm_backend as backend + +# We import the backend operations from the megablocks package as +# grouped_gemm is vendored in megablocks in this repository. +# from ... import _ops as backend +from megablocks._ops import ops as backend # type: ignore + +def _allocate_output(a, b, batch_sizes, trans_a, trans_b): + assert not (trans_a and trans_b) + assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" + assert a.ndim == 2, "Expected 2d tensor for 'a'" + assert b.ndim == (2 if trans_a else 3) + + shape = ( + (batch_sizes.shape[0], a.shape[1], b.shape[1]) + if trans_a else + (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) + ) + return torch.empty(*shape, device=a.device, dtype=a.dtype) + +def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): + if c is None: + c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) + backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) + return c diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py @@ -0,0 +1,33 @@ +from . 
import backend +import torch + + +class GroupedGemm(torch.autograd.Function): + + @staticmethod + def forward(ctx, a, b, batch_sizes, trans_b): + ctx.save_for_backward(a, b, batch_sizes) + ctx.trans_b = trans_b + return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) + + @staticmethod + def backward(ctx, grad): + grad = grad.contiguous() + a, b, batch_sizes = ctx.saved_tensors + trans_b = ctx.trans_b + + agrad = None + if ctx.needs_input_grad[0]: + agrad = backend.gmm( + grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) + + bgrad = None + if ctx.needs_input_grad[1]: + lhs, rhs = (grad, a) if trans_b else (a, grad) + bgrad = backend.gmm( + lhs, rhs, batch_sizes, trans_a=True, trans_b=False) + return agrad, bgrad, None, None + + +def gmm(a, b, batch_sizes, trans_b=False): + return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py index 6d3f977f360fd0ad5800c3b5da9ce57be794b9b8..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py @@ -4,7 +4,8 @@ import warnings _grouped_gemm_is_available: bool = False try: - import grouped_gemm + # import grouped_gemm + pass _grouped_gemm_is_available = True except ImportError as error: warnings.warn('Grouped GEMM not available.') @@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available(): assert _grouped_gemm_is_available, msg -backend = grouped_gemm.backend if grouped_gemm_is_available() else None -ops = grouped_gemm.ops if grouped_gemm_is_available() else None +# backend = grouped_gemm.backend if grouped_gemm_is_available() else None +# ops = grouped_gemm.ops if grouped_gemm_is_available() else None + + +from .grouped_gemm import backend as ops +from .grouped_gemm import ops as backend diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py index 849b023b1a0f765fb1e26addf4670dc6db785a52..a720e7a2cc4e44636f6e433a2750e945dc38e8b2 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 # from megablocks.layers.dmoe import dMoE -from megablocks.layers.moe import MoE +from .moe import MoE __all__ = [ 'MoE', diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/__init__.py index e9d8a2495c4baab89e5f90d84fa92f406a29ea10..af8d40aa52c706d2aae368cc11ade68cf13f6a47 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/__init__.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/__init__.py @@ -5,11 +5,15 @@ import torch from ._ops import ops -from megablocks.layers.arguments import Arguments -from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE -from megablocks.layers.glu import SparseGLU -from megablocks.layers.mlp import MLP, SparseMLP -from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from .grouped_gemm import backend as gg_backend +from .grouped_gemm import ops as gg_ops + + +from .layers.arguments import Arguments +from .layers.dmoe import ParallelDroplessMLP, dMoE +from .layers.glu import SparseGLU +from .layers.mlp import MLP, SparseMLP +from .layers.moe import MoE, 
ParallelMLP, get_load_balancing_loss # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so deleted file mode 100755 index 0be0cd1a5f3cab09576c25a7e8e799c8097bfdfc..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c9661a5704bd53f129788d2b9241f44cf00e1447dab8127d5a07675b1d6ca2ba -size 10444224 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..d7813fa66f923338b76230f282683661a5ff4022 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6dc0dcea20fc1350689addf7cb9927f7bb709f68ed89d4c711b0f7db579a463b +size 10510072 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py index fcd42432c910375fb1713b3bf29145e48b556498..fb18e304d3a97b1801f4f28ccd9b74403f400f6f 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_359242d -ops = torch.ops._megablocks_359242d +from . import _megablocks_a585153_dirty +ops = torch.ops._megablocks_a585153_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_359242d::{op_name}" \ No newline at end of file + return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py @@ -0,0 +1,2 @@ +from . import ops +from . import backend diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..7c692f95852950a7b97aaaa2aa3d15e4297b6192 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -0,0 +1,32 @@ +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# # TODO(tgale): Wrap this in a try-block with better +# # error message and instructions for building the +# # c++ operations. +# import grouped_gemm_backend as backend + +# We import the backend operations from the megablocks package as +# grouped_gemm is vendored in megablocks in this repository. +# from ... 
import _ops as backend +from megablocks._ops import ops as backend # type: ignore + +def _allocate_output(a, b, batch_sizes, trans_a, trans_b): + assert not (trans_a and trans_b) + assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" + assert a.ndim == 2, "Expected 2d tensor for 'a'" + assert b.ndim == (2 if trans_a else 3) + + shape = ( + (batch_sizes.shape[0], a.shape[1], b.shape[1]) + if trans_a else + (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) + ) + return torch.empty(*shape, device=a.device, dtype=a.dtype) + +def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): + if c is None: + c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) + backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) + return c diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py @@ -0,0 +1,33 @@ +from . import backend +import torch + + +class GroupedGemm(torch.autograd.Function): + + @staticmethod + def forward(ctx, a, b, batch_sizes, trans_b): + ctx.save_for_backward(a, b, batch_sizes) + ctx.trans_b = trans_b + return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) + + @staticmethod + def backward(ctx, grad): + grad = grad.contiguous() + a, b, batch_sizes = ctx.saved_tensors + trans_b = ctx.trans_b + + agrad = None + if ctx.needs_input_grad[0]: + agrad = backend.gmm( + grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) + + bgrad = None + if ctx.needs_input_grad[1]: + lhs, rhs = (grad, a) if trans_b else (a, grad) + bgrad = backend.gmm( + lhs, rhs, batch_sizes, trans_a=True, trans_b=False) + return agrad, bgrad, None, None + + +def gmm(a, b, batch_sizes, trans_b=False): + return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm_util.py index 6d3f977f360fd0ad5800c3b5da9ce57be794b9b8..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm_util.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm_util.py @@ -4,7 +4,8 @@ import warnings _grouped_gemm_is_available: bool = False try: - import grouped_gemm + # import grouped_gemm + pass _grouped_gemm_is_available = True except ImportError as error: warnings.warn('Grouped GEMM not available.') @@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available(): assert _grouped_gemm_is_available, msg -backend = grouped_gemm.backend if grouped_gemm_is_available() else None -ops = grouped_gemm.ops if grouped_gemm_is_available() else None +# backend = grouped_gemm.backend if grouped_gemm_is_available() else None +# ops = grouped_gemm.ops if grouped_gemm_is_available() else None + + +from .grouped_gemm import backend as ops +from .grouped_gemm import ops as backend diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/__init__.py index 849b023b1a0f765fb1e26addf4670dc6db785a52..a720e7a2cc4e44636f6e433a2750e945dc38e8b2 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/__init__.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/__init__.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 
# from megablocks.layers.dmoe import dMoE -from megablocks.layers.moe import MoE +from .moe import MoE __all__ = [ 'MoE', diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/__init__.py index e9d8a2495c4baab89e5f90d84fa92f406a29ea10..af8d40aa52c706d2aae368cc11ade68cf13f6a47 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/__init__.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/__init__.py @@ -5,11 +5,15 @@ import torch from ._ops import ops -from megablocks.layers.arguments import Arguments -from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE -from megablocks.layers.glu import SparseGLU -from megablocks.layers.mlp import MLP, SparseMLP -from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from .grouped_gemm import backend as gg_backend +from .grouped_gemm import ops as gg_ops + + +from .layers.arguments import Arguments +from .layers.dmoe import ParallelDroplessMLP, dMoE +from .layers.glu import SparseGLU +from .layers.mlp import MLP, SparseMLP +from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so deleted file mode 100755 index abd969561ee41282ee024c3a0326ab521b4e7a00..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_359242d.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ec45bdb77d89916e5c58e17ddb76148799e3cac07fa6f4e93bf9140d9b2039bb -size 11788400 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..e298e56ddd77607fd0010213f29fbfd2ceb6e2b0 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb0c2b91105c2f32f590aaa9d90ae2d6b36834bae9b35fb55c4b4fc90da56bc3 +size 11857952 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py index fcd42432c910375fb1713b3bf29145e48b556498..fb18e304d3a97b1801f4f28ccd9b74403f400f6f 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_359242d -ops = torch.ops._megablocks_359242d +from . import _megablocks_a585153_dirty +ops = torch.ops._megablocks_a585153_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_megablocks_359242d::{op_name}" \ No newline at end of file + return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/__init__.py @@ -0,0 +1,2 @@ +from . import ops +from . import backend diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..7c692f95852950a7b97aaaa2aa3d15e4297b6192 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -0,0 +1,32 @@ +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# # TODO(tgale): Wrap this in a try-block with better +# # error message and instructions for building the +# # c++ operations. +# import grouped_gemm_backend as backend + +# We import the backend operations from the megablocks package as +# grouped_gemm is vendored in megablocks in this repository. +# from ... import _ops as backend +from megablocks._ops import ops as backend # type: ignore + +def _allocate_output(a, b, batch_sizes, trans_a, trans_b): + assert not (trans_a and trans_b) + assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" + assert a.ndim == 2, "Expected 2d tensor for 'a'" + assert b.ndim == (2 if trans_a else 3) + + shape = ( + (batch_sizes.shape[0], a.shape[1], b.shape[1]) + if trans_a else + (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) + ) + return torch.empty(*shape, device=a.device, dtype=a.dtype) + +def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): + if c is None: + c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) + backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) + return c diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/ops.py @@ -0,0 +1,33 @@ +from . 
import backend +import torch + + +class GroupedGemm(torch.autograd.Function): + + @staticmethod + def forward(ctx, a, b, batch_sizes, trans_b): + ctx.save_for_backward(a, b, batch_sizes) + ctx.trans_b = trans_b + return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) + + @staticmethod + def backward(ctx, grad): + grad = grad.contiguous() + a, b, batch_sizes = ctx.saved_tensors + trans_b = ctx.trans_b + + agrad = None + if ctx.needs_input_grad[0]: + agrad = backend.gmm( + grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) + + bgrad = None + if ctx.needs_input_grad[1]: + lhs, rhs = (grad, a) if trans_b else (a, grad) + bgrad = backend.gmm( + lhs, rhs, batch_sizes, trans_a=True, trans_b=False) + return agrad, bgrad, None, None + + +def gmm(a, b, batch_sizes, trans_b=False): + return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm_util.py index 6d3f977f360fd0ad5800c3b5da9ce57be794b9b8..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm_util.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm_util.py @@ -4,7 +4,8 @@ import warnings _grouped_gemm_is_available: bool = False try: - import grouped_gemm + # import grouped_gemm + pass _grouped_gemm_is_available = True except ImportError as error: warnings.warn('Grouped GEMM not available.') @@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available(): assert _grouped_gemm_is_available, msg -backend = grouped_gemm.backend if grouped_gemm_is_available() else None -ops = grouped_gemm.ops if grouped_gemm_is_available() else None +# backend = grouped_gemm.backend if grouped_gemm_is_available() else None +# ops = grouped_gemm.ops if grouped_gemm_is_available() else None + + +from .grouped_gemm import backend as ops +from .grouped_gemm import ops as backend diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/__init__.py index 849b023b1a0f765fb1e26addf4670dc6db785a52..a720e7a2cc4e44636f6e433a2750e945dc38e8b2 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/__init__.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/__init__.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 # from megablocks.layers.dmoe import dMoE -from megablocks.layers.moe import MoE +from .moe import MoE __all__ = [ 'MoE', diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/__init__.py index e9d8a2495c4baab89e5f90d84fa92f406a29ea10..af8d40aa52c706d2aae368cc11ade68cf13f6a47 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/__init__.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/__init__.py @@ -5,11 +5,15 @@ import torch from ._ops import ops -from megablocks.layers.arguments import Arguments -from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE -from megablocks.layers.glu import SparseGLU -from megablocks.layers.mlp import MLP, SparseMLP -from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from .grouped_gemm import backend as gg_backend +from .grouped_gemm import ops as gg_ops + + +from .layers.arguments import Arguments +from .layers.dmoe import ParallelDroplessMLP, dMoE +from .layers.glu import SparseGLU +from .layers.mlp import MLP, SparseMLP +from .layers.moe import MoE, 
ParallelMLP, get_load_balancing_loss # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so deleted file mode 100755 index 6ff924df77bd5ce379e5f717d59e29f407da7962..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:99859c18a9a4e7ec11e7fa805e2225644f8e5f51e2546b4525ddf8e939f48874 -size 11832392 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..bc0ec44d4aeb316188bc43b4704fef6028c9af92 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bd817eed5069e786933346cb2bb5ab6f586878ae80647191932336dec3295c96 +size 11923704 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py index fcd42432c910375fb1713b3bf29145e48b556498..fb18e304d3a97b1801f4f28ccd9b74403f400f6f 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_359242d -ops = torch.ops._megablocks_359242d +from . import _megablocks_a585153_dirty +ops = torch.ops._megablocks_a585153_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_359242d::{op_name}" \ No newline at end of file + return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py @@ -0,0 +1,2 @@ +from . import ops +from . import backend diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..7c692f95852950a7b97aaaa2aa3d15e4297b6192 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -0,0 +1,32 @@ +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# # TODO(tgale): Wrap this in a try-block with better +# # error message and instructions for building the +# # c++ operations. +# import grouped_gemm_backend as backend + +# We import the backend operations from the megablocks package as +# grouped_gemm is vendored in megablocks in this repository. +# from ... 
import _ops as backend +from megablocks._ops import ops as backend # type: ignore + +def _allocate_output(a, b, batch_sizes, trans_a, trans_b): + assert not (trans_a and trans_b) + assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" + assert a.ndim == 2, "Expected 2d tensor for 'a'" + assert b.ndim == (2 if trans_a else 3) + + shape = ( + (batch_sizes.shape[0], a.shape[1], b.shape[1]) + if trans_a else + (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) + ) + return torch.empty(*shape, device=a.device, dtype=a.dtype) + +def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): + if c is None: + c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) + backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) + return c diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py @@ -0,0 +1,33 @@ +from . import backend +import torch + + +class GroupedGemm(torch.autograd.Function): + + @staticmethod + def forward(ctx, a, b, batch_sizes, trans_b): + ctx.save_for_backward(a, b, batch_sizes) + ctx.trans_b = trans_b + return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) + + @staticmethod + def backward(ctx, grad): + grad = grad.contiguous() + a, b, batch_sizes = ctx.saved_tensors + trans_b = ctx.trans_b + + agrad = None + if ctx.needs_input_grad[0]: + agrad = backend.gmm( + grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) + + bgrad = None + if ctx.needs_input_grad[1]: + lhs, rhs = (grad, a) if trans_b else (a, grad) + bgrad = backend.gmm( + lhs, rhs, batch_sizes, trans_a=True, trans_b=False) + return agrad, bgrad, None, None + + +def gmm(a, b, batch_sizes, trans_b=False): + return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm_util.py index 6d3f977f360fd0ad5800c3b5da9ce57be794b9b8..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm_util.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm_util.py @@ -4,7 +4,8 @@ import warnings _grouped_gemm_is_available: bool = False try: - import grouped_gemm + # import grouped_gemm + pass _grouped_gemm_is_available = True except ImportError as error: warnings.warn('Grouped GEMM not available.') @@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available(): assert _grouped_gemm_is_available, msg -backend = grouped_gemm.backend if grouped_gemm_is_available() else None -ops = grouped_gemm.ops if grouped_gemm_is_available() else None +# backend = grouped_gemm.backend if grouped_gemm_is_available() else None +# ops = grouped_gemm.ops if grouped_gemm_is_available() else None + + +from .grouped_gemm import backend as ops +from .grouped_gemm import ops as backend diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/__init__.py index 849b023b1a0f765fb1e26addf4670dc6db785a52..a720e7a2cc4e44636f6e433a2750e945dc38e8b2 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/__init__.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/__init__.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 
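For reference, the gradient logic in `GroupedGemm.backward` above amounts to re-running the per-expert matmuls with transposed operands. The pure-PyTorch sketch below (a hypothetical `gmm_reference` helper, CPU-only, no fused kernel) mirrors those semantics and can be used to sanity-check the wrapper:

```python
# Hypothetical reference implementation of the grouped-GEMM semantics,
# for sanity checks only; it is not part of the package.
import torch

def gmm_reference(a, b, batch_sizes, trans_b=False):
    # a: (tokens, k); b: (num_experts, k, n), or (num_experts, n, k) if trans_b.
    out, start = [], 0
    for e, rows in enumerate(batch_sizes.tolist()):
        w = b[e].t() if trans_b else b[e]
        out.append(a[start:start + rows] @ w)
        start += rows
    return torch.cat(out, dim=0)

batch_sizes = torch.tensor([2, 0, 3])
a = torch.randn(5, 8, requires_grad=True)
b = torch.randn(3, 8, 4, requires_grad=True)
gmm_reference(a, b, batch_sizes).sum().backward()
# Mirrors GroupedGemm.backward: per expert slice, a.grad = grad_out @ b[e].T
# (backend.gmm(grad, b, trans_b=True)) and b.grad[e] = a_slice.T @ grad_out
# (backend.gmm(a, grad, trans_a=True)).
```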
# from megablocks.layers.dmoe import dMoE -from megablocks.layers.moe import MoE +from .moe import MoE __all__ = [ 'MoE', diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py index e9d8a2495c4baab89e5f90d84fa92f406a29ea10..af8d40aa52c706d2aae368cc11ade68cf13f6a47 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py @@ -5,11 +5,15 @@ import torch from ._ops import ops -from megablocks.layers.arguments import Arguments -from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE -from megablocks.layers.glu import SparseGLU -from megablocks.layers.mlp import MLP, SparseMLP -from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from .grouped_gemm import backend as gg_backend +from .grouped_gemm import ops as gg_ops + + +from .layers.arguments import Arguments +from .layers.dmoe import ParallelDroplessMLP, dMoE +from .layers.glu import SparseGLU +from .layers.mlp import MLP, SparseMLP +from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so deleted file mode 100755 index a9116710809dc76b278bf71b0c78b35fa9c00c5f..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_359242d.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a1d87aabde7c05e3b4d3fd9623b89db0f9b7caa444a2c8f590fc8227211cd2d3 -size 10456616 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..b774ea5c271fea98a48c9c21d7fb1a824d73e065 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6edc24e0f999c3ecb6120715db285792c9f19dbe387348fe0e0f25c72e97138 +size 10517848 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py index fcd42432c910375fb1713b3bf29145e48b556498..fb18e304d3a97b1801f4f28ccd9b74403f400f6f 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_359242d -ops = torch.ops._megablocks_359242d +from . import _megablocks_a585153_dirty +ops = torch.ops._megablocks_a585153_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_megablocks_359242d::{op_name}" \ No newline at end of file + return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py @@ -0,0 +1,2 @@ +from . import ops +from . import backend diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..7c692f95852950a7b97aaaa2aa3d15e4297b6192 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -0,0 +1,32 @@ +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# # TODO(tgale): Wrap this in a try-block with better +# # error message and instructions for building the +# # c++ operations. +# import grouped_gemm_backend as backend + +# We import the backend operations from the megablocks package as +# grouped_gemm is vendored in megablocks in this repository. +# from ... import _ops as backend +from megablocks._ops import ops as backend # type: ignore + +def _allocate_output(a, b, batch_sizes, trans_a, trans_b): + assert not (trans_a and trans_b) + assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" + assert a.ndim == 2, "Expected 2d tensor for 'a'" + assert b.ndim == (2 if trans_a else 3) + + shape = ( + (batch_sizes.shape[0], a.shape[1], b.shape[1]) + if trans_a else + (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) + ) + return torch.empty(*shape, device=a.device, dtype=a.dtype) + +def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): + if c is None: + c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) + backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) + return c diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py @@ -0,0 +1,33 @@ +from . 
import backend +import torch + + +class GroupedGemm(torch.autograd.Function): + + @staticmethod + def forward(ctx, a, b, batch_sizes, trans_b): + ctx.save_for_backward(a, b, batch_sizes) + ctx.trans_b = trans_b + return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) + + @staticmethod + def backward(ctx, grad): + grad = grad.contiguous() + a, b, batch_sizes = ctx.saved_tensors + trans_b = ctx.trans_b + + agrad = None + if ctx.needs_input_grad[0]: + agrad = backend.gmm( + grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) + + bgrad = None + if ctx.needs_input_grad[1]: + lhs, rhs = (grad, a) if trans_b else (a, grad) + bgrad = backend.gmm( + lhs, rhs, batch_sizes, trans_a=True, trans_b=False) + return agrad, bgrad, None, None + + +def gmm(a, b, batch_sizes, trans_b=False): + return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py index 6d3f977f360fd0ad5800c3b5da9ce57be794b9b8..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py @@ -4,7 +4,8 @@ import warnings _grouped_gemm_is_available: bool = False try: - import grouped_gemm + # import grouped_gemm + pass _grouped_gemm_is_available = True except ImportError as error: warnings.warn('Grouped GEMM not available.') @@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available(): assert _grouped_gemm_is_available, msg -backend = grouped_gemm.backend if grouped_gemm_is_available() else None -ops = grouped_gemm.ops if grouped_gemm_is_available() else None +# backend = grouped_gemm.backend if grouped_gemm_is_available() else None +# ops = grouped_gemm.ops if grouped_gemm_is_available() else None + + +from .grouped_gemm import backend as ops +from .grouped_gemm import ops as backend diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py index 849b023b1a0f765fb1e26addf4670dc6db785a52..a720e7a2cc4e44636f6e433a2750e945dc38e8b2 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 # from megablocks.layers.dmoe import dMoE -from megablocks.layers.moe import MoE +from .moe import MoE __all__ = [ 'MoE', diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py index e9d8a2495c4baab89e5f90d84fa92f406a29ea10..af8d40aa52c706d2aae368cc11ade68cf13f6a47 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py @@ -5,11 +5,15 @@ import torch from ._ops import ops -from megablocks.layers.arguments import Arguments -from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE -from megablocks.layers.glu import SparseGLU -from megablocks.layers.mlp import MLP, SparseMLP -from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from .grouped_gemm import backend as gg_backend +from .grouped_gemm import ops as gg_ops + + +from .layers.arguments import Arguments +from .layers.dmoe import ParallelDroplessMLP, dMoE +from .layers.glu import SparseGLU +from .layers.mlp import MLP, SparseMLP +from .layers.moe import MoE, 
ParallelMLP, get_load_balancing_loss # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so deleted file mode 100755 index 7d62280654256c29f19dc4a96c5b895591d064be..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_359242d.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9d15bbd83210a53d26b152adb3388900213fd5ff70d5b77a61acac53b6e89fbe -size 11835920 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..6485829ebc3f4f120eac78ed896d64000416c3ac --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c8ad568311ee3b60bd6c53500a01213c767afd6d455b10f0b4e23a1a162839d8 +size 11931112 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py index fcd42432c910375fb1713b3bf29145e48b556498..fb18e304d3a97b1801f4f28ccd9b74403f400f6f 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_359242d -ops = torch.ops._megablocks_359242d +from . import _megablocks_a585153_dirty +ops = torch.ops._megablocks_a585153_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_359242d::{op_name}" \ No newline at end of file + return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py @@ -0,0 +1,2 @@ +from . import ops +from . import backend diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..7c692f95852950a7b97aaaa2aa3d15e4297b6192 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -0,0 +1,32 @@ +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# # TODO(tgale): Wrap this in a try-block with better +# # error message and instructions for building the +# # c++ operations. +# import grouped_gemm_backend as backend + +# We import the backend operations from the megablocks package as +# grouped_gemm is vendored in megablocks in this repository. +# from ... 
import _ops as backend +from megablocks._ops import ops as backend # type: ignore + +def _allocate_output(a, b, batch_sizes, trans_a, trans_b): + assert not (trans_a and trans_b) + assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" + assert a.ndim == 2, "Expected 2d tensor for 'a'" + assert b.ndim == (2 if trans_a else 3) + + shape = ( + (batch_sizes.shape[0], a.shape[1], b.shape[1]) + if trans_a else + (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) + ) + return torch.empty(*shape, device=a.device, dtype=a.dtype) + +def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): + if c is None: + c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) + backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) + return c diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py @@ -0,0 +1,33 @@ +from . import backend +import torch + + +class GroupedGemm(torch.autograd.Function): + + @staticmethod + def forward(ctx, a, b, batch_sizes, trans_b): + ctx.save_for_backward(a, b, batch_sizes) + ctx.trans_b = trans_b + return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) + + @staticmethod + def backward(ctx, grad): + grad = grad.contiguous() + a, b, batch_sizes = ctx.saved_tensors + trans_b = ctx.trans_b + + agrad = None + if ctx.needs_input_grad[0]: + agrad = backend.gmm( + grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) + + bgrad = None + if ctx.needs_input_grad[1]: + lhs, rhs = (grad, a) if trans_b else (a, grad) + bgrad = backend.gmm( + lhs, rhs, batch_sizes, trans_a=True, trans_b=False) + return agrad, bgrad, None, None + + +def gmm(a, b, batch_sizes, trans_b=False): + return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py index 6d3f977f360fd0ad5800c3b5da9ce57be794b9b8..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py @@ -4,7 +4,8 @@ import warnings _grouped_gemm_is_available: bool = False try: - import grouped_gemm + # import grouped_gemm + pass _grouped_gemm_is_available = True except ImportError as error: warnings.warn('Grouped GEMM not available.') @@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available(): assert _grouped_gemm_is_available, msg -backend = grouped_gemm.backend if grouped_gemm_is_available() else None -ops = grouped_gemm.ops if grouped_gemm_is_available() else None +# backend = grouped_gemm.backend if grouped_gemm_is_available() else None +# ops = grouped_gemm.ops if grouped_gemm_is_available() else None + + +from .grouped_gemm import backend as ops +from .grouped_gemm import ops as backend diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py index 849b023b1a0f765fb1e26addf4670dc6db785a52..a720e7a2cc4e44636f6e433a2750e945dc38e8b2 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 
# from megablocks.layers.dmoe import dMoE -from megablocks.layers.moe import MoE +from .moe import MoE __all__ = [ 'MoE', diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py index e9d8a2495c4baab89e5f90d84fa92f406a29ea10..af8d40aa52c706d2aae368cc11ade68cf13f6a47 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py @@ -5,11 +5,15 @@ import torch from ._ops import ops -from megablocks.layers.arguments import Arguments -from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE -from megablocks.layers.glu import SparseGLU -from megablocks.layers.mlp import MLP, SparseMLP -from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from .grouped_gemm import backend as gg_backend +from .grouped_gemm import ops as gg_ops + + +from .layers.arguments import Arguments +from .layers.dmoe import ParallelDroplessMLP, dMoE +from .layers.glu import SparseGLU +from .layers.mlp import MLP, SparseMLP +from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_359242d.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_359242d.abi3.so deleted file mode 100755 index f58a6ba85d95e023464dafbb7d4f75d60bf639f1..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_359242d.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:122a39925ccc50597aac33bcd4eec4d2d3ff29846c3a79671da2c554d4967777 -size 17748280 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..d2b18b163b48d73368dc51792783bbddb2c00622 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_a585153_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca88225928254532529d6a45b819b44f4932085ac90d966ea31a824777c4d581 +size 17892656 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py index fcd42432c910375fb1713b3bf29145e48b556498..fb18e304d3a97b1801f4f28ccd9b74403f400f6f 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_359242d -ops = torch.ops._megablocks_359242d +from . import _megablocks_a585153_dirty +ops = torch.ops._megablocks_a585153_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_megablocks_359242d::{op_name}" \ No newline at end of file + return f"_megablocks_a585153_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/__init__.py @@ -0,0 +1,2 @@ +from . import ops +from . import backend diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..7c692f95852950a7b97aaaa2aa3d15e4297b6192 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -0,0 +1,32 @@ +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# # TODO(tgale): Wrap this in a try-block with better +# # error message and instructions for building the +# # c++ operations. +# import grouped_gemm_backend as backend + +# We import the backend operations from the megablocks package as +# grouped_gemm is vendored in megablocks in this repository. +# from ... import _ops as backend +from megablocks._ops import ops as backend # type: ignore + +def _allocate_output(a, b, batch_sizes, trans_a, trans_b): + assert not (trans_a and trans_b) + assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" + assert a.ndim == 2, "Expected 2d tensor for 'a'" + assert b.ndim == (2 if trans_a else 3) + + shape = ( + (batch_sizes.shape[0], a.shape[1], b.shape[1]) + if trans_a else + (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) + ) + return torch.empty(*shape, device=a.device, dtype=a.dtype) + +def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): + if c is None: + c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) + backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) + return c diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/ops.py @@ -0,0 +1,33 @@ +from . 
import backend +import torch + + +class GroupedGemm(torch.autograd.Function): + + @staticmethod + def forward(ctx, a, b, batch_sizes, trans_b): + ctx.save_for_backward(a, b, batch_sizes) + ctx.trans_b = trans_b + return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) + + @staticmethod + def backward(ctx, grad): + grad = grad.contiguous() + a, b, batch_sizes = ctx.saved_tensors + trans_b = ctx.trans_b + + agrad = None + if ctx.needs_input_grad[0]: + agrad = backend.gmm( + grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) + + bgrad = None + if ctx.needs_input_grad[1]: + lhs, rhs = (grad, a) if trans_b else (a, grad) + bgrad = backend.gmm( + lhs, rhs, batch_sizes, trans_a=True, trans_b=False) + return agrad, bgrad, None, None + + +def gmm(a, b, batch_sizes, trans_b=False): + return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm_util.py index 6d3f977f360fd0ad5800c3b5da9ce57be794b9b8..1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm_util.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm_util.py @@ -4,7 +4,8 @@ import warnings _grouped_gemm_is_available: bool = False try: - import grouped_gemm + # import grouped_gemm + pass _grouped_gemm_is_available = True except ImportError as error: warnings.warn('Grouped GEMM not available.') @@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available(): assert _grouped_gemm_is_available, msg -backend = grouped_gemm.backend if grouped_gemm_is_available() else None -ops = grouped_gemm.ops if grouped_gemm_is_available() else None +# backend = grouped_gemm.backend if grouped_gemm_is_available() else None +# ops = grouped_gemm.ops if grouped_gemm_is_available() else None + + +from .grouped_gemm import backend as ops +from .grouped_gemm import ops as backend diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/__init__.py index 849b023b1a0f765fb1e26addf4670dc6db785a52..a720e7a2cc4e44636f6e433a2750e945dc38e8b2 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/__init__.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/__init__.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 # from megablocks.layers.dmoe import dMoE -from megablocks.layers.moe import MoE +from .moe import MoE __all__ = [ 'MoE', diff --git a/tests/ops_test.py b/tests/ops_test.py index 3771e38362915fececbf64ff9dea971f65ab9cb4..f94388dba3b8e738d918db2537bab7d13349c2a4 100644 --- a/tests/ops_test.py +++ b/tests/ops_test.py @@ -1,10 +1,11 @@ -import unittest -import itertools +import torch +import megablocks +import unittest from absl.testing import parameterized -import megablocks -import numpy as np -import torch + +# import itertools +# import numpy as np def allclose(x, y, pct=2.0):