kernel
drbh commited on
Commit
9c4ca75
·
1 Parent(s): a4f6452

feat: validate build with original test suite

Browse files
README.md CHANGED
@@ -4,3 +4,39 @@ tags:
4
  - kernel
5
  ---
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  - kernel
5
  ---
6
 
7
+
8
+
9
+ ```bash
10
+ nix develop --show-trace -i -L .#test --command python -m pytest -s tests
11
+ ```
12
+
13
+ Expected output (abridged):
14
+
15
+ ```
16
+ ============== test session starts ===============
17
+ platform linux -- Python 3.12.10, pytest-8.3.5, pluggy-1.5.0
18
+ rootdir: /home/ubuntu/Projects/megablocks-moe
19
+ plugins: hypothesis-6.130.12
20
+ collecting 43 items world_size=1
21
+ collected 387 items
22
+
23
+ tests/layers/moe_test.py ...........................................
24
+ tests/ops/binned_gather_test.py .....................
25
+ tests/ops/binned_scatter_test.py .....................
26
+ tests/ops/cumsum_test.py ................................
27
+ tests/ops/histogram_test.py ......................................................
28
+ tests/ops/padded_gather_test.py ......................................
29
+ tests/ops/padded_scatter_test.py ......................................................
30
+ tests/ops/replicate_test.py ..................................................................................
31
+ tests/ops/sort_test.py ..................
32
+ tests/ops/topology_test.py ....................
33
+ tests/test_mb_moe.py megablocks_moe module imported successfully.
34
+ Available functions: ['Arguments', 'MLP', 'MoE', 'ParallelDroplessMLP', 'ParallelMLP', 'SparseGLU', 'SparseMLP', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_megablocks_a4f6452_dirty', '_ops', 'argsort', 'backend', 'cumsum', 'dMoE', 'exclusive_cumsum', 'get_load_balancing_loss', 'grouped_gemm_util', 'histogram', 'inclusive_cumsum', 'indices', 'layers', 'ops', 'replicate_backward', 'replicate_forward', 'sort', 'torch']
35
+ .cumsum output: tensor([0, 1, 3, 6], device='cuda:0', dtype=torch.int16)
36
+ ...
37
+
38
+ ================ warnings summary ================
39
+ ...
40
+ -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
41
+ ======= 387 passed, 18 warnings in 54.63s ========
42
+ ```
build.toml CHANGED
@@ -10,6 +10,20 @@ src = [
10
 
11
  [kernel.megablocks]
12
  backend = "cuda"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  src = [
14
  "csrc/new_cumsum.h",
15
  "csrc/new_cumsum.cu",
@@ -22,9 +36,17 @@ src = [
22
  "csrc/new_sort.h",
23
  "csrc/new_sort.cu",
24
  ]
25
- depends = [ "torch", "cutlass_3_8" ]
26
 
27
  [test]
28
  python-git-packages = [
29
- { url = "https://github.com/stanford-futuredata/stk.git", rev = "7363137", sha256 = "0m6g5l9nlwaiwybg5j8dhnz159wdpabdnkzapnn3dsifxrsb59vz" }
 
 
 
 
 
 
 
 
 
30
  ]
 
10
 
11
  [kernel.megablocks]
12
  backend = "cuda"
13
+ cuda-capabilities = [
14
+ "7.0",
15
+ "7.2",
16
+ "7.5",
17
+ "8.0",
18
+ "8.6",
19
+ "8.7",
20
+ "8.9",
21
+ "9.0",
22
+ "10.0",
23
+ "10.1",
24
+ "12.0",
25
+ ]
26
+ depends = ["torch", "cutlass_3_8"]
27
  src = [
28
  "csrc/new_cumsum.h",
29
  "csrc/new_cumsum.cu",
 
36
  "csrc/new_sort.h",
37
  "csrc/new_sort.cu",
38
  ]
 
39
 
40
  [test]
41
  python-git-packages = [
42
+ { url = "https://github.com/stanford-futuredata/stk.git", rev = "7363137", sha256 = "0m6g5l9nlwaiwybg5j8dhnz159wdpabdnkzapnn3dsifxrsb59vz" },
43
+ { url = "https://github.com/mosaicml/composer.git", rev = "v0.9.0", sha256 = "ekJ5nE6JwYY6Ld9kIk72R/a3iI943Gd5yvAkBHQs5aI=" },
44
+ # { url = "https://github.com/tgale96/grouped_gemm.git", rev = "v0.3.0", sha256 = "sha256-fS6MuDj6yQ00CSzFrmAmM20/ccvtLJ1MFjfeqdwuPl8=" }
45
+ ]
46
+ python-packages = [
47
+ "tqdm",
48
+ "py-cpuinfo",
49
+ "importlib-metadata",
50
+ "torchmetrics",
51
+ # "yahp"
52
  ]
tests/conftest.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import os
5
+ from typing import List, Optional
6
+
7
+ import pytest
8
+ # from composer.utils import reproducibility
9
+
10
+ # Allowed options for pytest.mark.world_size()
11
+ WORLD_SIZE_OPTIONS = (1, 2)
12
+
13
+ # Enforce deterministic mode before any tests start.
14
+ # reproducibility.configure_deterministic_mode()
15
+
16
+ # TODO: allow plugins when deps resolved
17
+
18
+ # Add the path of any pytest fixture files you want to make global
19
+ pytest_plugins = [
20
+ # 'tests.fixtures.autouse',
21
+ 'tests.fixtures.fixtures',
22
+ ]
23
+
24
+
25
def _get_world_size(item: pytest.Item):
    """Return the ``world_size`` requested by a collected test item.

    When the item carries no ``world_size`` marker, a synthetic
    ``world_size(1)`` mark is used, so the result defaults to 1.
    """
    fallback_mark = pytest.mark.world_size(1).mark
    closest = item.get_closest_marker('world_size', default=fallback_mark)
    return closest.args[0]
29
+
30
+
31
def _get_option(
    config: pytest.Config,
    name: str,
    default: Optional[str] = None,
) -> str:  # type: ignore
    """Resolve option ``name`` from the command line, then the ini file.

    Lookup order: the CLI value, then the ini value, then ``default``. An ini
    value of ``[]`` is treated the same as ``None``. If nothing is found and
    no ``default`` was given, the test run is failed via ``pytest.fail``.
    """
    cli_value = config.getoption(name)
    if cli_value is not None:
        assert isinstance(cli_value, str)
        return cli_value

    resolved = config.getini(name)
    if resolved == [] or resolved is None:
        if default is None:
            pytest.fail(f'Config option {name} is not specified but is required',)
        resolved = default

    assert isinstance(resolved, str)
    return resolved
49
+
50
+
51
def _add_option(
    parser: pytest.Parser,
    name: str,
    help: str,
    choices: Optional[list[str]] = None,
):
    """Register ``name`` as both a ``--name`` CLI flag and an ini setting.

    Both registrations use ``None`` as their default; ``choices`` only
    restricts the command-line flag.
    """
    cli_flag = f'--{name}'
    parser.addoption(cli_flag, default=None, type=str, choices=choices, help=help)
    parser.addini(name=name, help=help, type='string', default=None)
70
+
71
+
72
def pytest_collection_modifyitems(
    config: pytest.Config,
    items: List[pytest.Item],
) -> None:
    """Keep only the tests whose ``world_size`` matches ``$WORLD_SIZE``.

    ``WORLD_SIZE`` is read from the environment (default ``1``). Items that
    do not match are reported through the ``pytest_deselected`` hook and
    removed from ``items`` in place.
    """
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    print(f'world_size={world_size}')

    def _is_selected(item: pytest.Item) -> bool:
        # Single filter today; extend here if more conditions are needed.
        return _get_world_size(item) == world_size

    kept = []
    dropped = []
    for item in items:
        bucket = kept if _is_selected(item) else dropped
        bucket.append(item)

    if dropped:
        config.hook.pytest_deselected(items=dropped)
        items[:] = kept
96
+
97
+
98
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register the custom ``seed`` option (CLI flag and ini setting)."""
    seed_help = """\
        Rank zero seed to use. `reproducibility.seed_all(seed + dist.get_global_rank())` will be invoked
        before each test."""
    _add_option(parser, 'seed', help=seed_help)
106
+
107
+
108
def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
    """Treat a run that collected no tests as a success.

    With world-size filtering, a rank may legitimately end up with no tests
    to run; pytest reports that as ``ExitCode.NO_TESTS_COLLECTED`` (5), which
    would otherwise fail the job.

    Args:
        session: The pytest session whose exit status may be overridden.
        exitstatus: The exit status pytest computed for the session.
    """
    # ExitCode is an IntEnum, so this also matches a plain integer 5.
    if exitstatus == pytest.ExitCode.NO_TESTS_COLLECTED:
        session.exitstatus = pytest.ExitCode.OK  # Ignore no-test-ran errors
tests/fixtures/autouse.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import gc
5
+ import logging
6
+ import os
7
+
8
+ import composer
9
+ import pytest
10
+ import torch
11
+ from composer.devices import DeviceCPU, DeviceGPU
12
+ from composer.utils import dist, reproducibility
13
+
14
+
15
@pytest.fixture(autouse=True)
def clear_cuda_cache(request: pytest.FixtureRequest):
    """Free cached CUDA memory before each test marked ``gpu``."""
    is_gpu_test = request.node.get_closest_marker('gpu') is not None
    if not (is_gpu_test and torch.cuda.is_available()):
        return
    torch.cuda.empty_cache()
    # Only gc on GPU tests as it 2x slows down CPU tests
    gc.collect()
22
+
23
+
24
@pytest.fixture(autouse=True)
def reset_mlflow_tracking_dir():
    """Clear the MLFlow tracking URI so it does not leak between tests.

    Does nothing when MLFlow is not installed.
    """
    try:
        import mlflow as _mlflow
        _mlflow.set_tracking_uri(None)  # type: ignore
    except ModuleNotFoundError:
        # MLFlow is optional; nothing to reset without it.
        return
33
+
34
+
35
@pytest.fixture(scope='session')
def cleanup_dist():
    """Synchronise all ranks once the session's tests have finished."""
    yield
    # Teardown: wait for every rank, so that no rank tears down the file
    # system while another rank's test is still writing to a file.
    dist.barrier()
42
+
43
+
44
@pytest.fixture(autouse=True, scope='session')
def configure_dist(request: pytest.FixtureRequest):
    """Initialise ``dist`` once per session when running with several ranks.

    Doing this globally means tests that do not use the trainer need not
    configure dist themselves. With a world size of 1 nothing is done.
    """
    if dist.get_world_size() == 1:
        return

    # The device is chosen from the first collected item only: GPU when that
    # item is marked ``gpu``, CPU otherwise.
    first_item = next(iter(request.session.items), None)
    device = None
    if first_item is not None:
        if first_item.get_closest_marker('gpu') is None:
            device = DeviceCPU()
        else:
            device = DeviceGPU()

    assert device is not None

    if not dist.is_initialized():
        dist.initialize_dist(device, timeout=300.0)
    # Hold pytest until every rank reaches this point, so that no rank starts
    # a test before the others are ready (a possible source of random
    # timeouts, e.g. rank 1 starting the next test while rank 0 is still
    # finishing the previous one).
    dist.barrier()
67
+
68
+
69
@pytest.fixture(autouse=True)
def set_log_levels():
    """Set the ``composer`` logger to DEBUG for every test."""
    logging.basicConfig()
    composer_logger = logging.getLogger(composer.__name__)
    composer_logger.setLevel(logging.DEBUG)
74
+
75
+
76
@pytest.fixture(autouse=True)
def seed_all(rank_zero_seed: int, monkeypatch: pytest.MonkeyPatch):
    """Seed every test deterministically.

    ``reproducibility.get_random_seed`` is patched to always return the rank
    zero seed, and the random seed is then set to the rank-local value
    ``rank_zero_seed + dist.get_global_rank()``.
    """

    def _fixed_seed():
        return rank_zero_seed

    monkeypatch.setattr(reproducibility, 'get_random_seed', _fixed_seed)
    rank_local_seed = rank_zero_seed + dist.get_global_rank()
    reproducibility.seed_all(rank_local_seed)
89
+
90
+
91
@pytest.fixture(autouse=True)
def remove_run_name_env_var():
    """Hide the run-name environment variables while a test runs.

    ``COMPOSER_RUN_NAME`` and ``RUN_NAME`` are removed before the test and put
    back afterwards if they were set.
    """
    # pop() both reads the previous value and removes the variable.
    saved_composer_run_name = os.environ.pop('COMPOSER_RUN_NAME', None)
    saved_run_name = os.environ.pop('RUN_NAME', None)

    yield

    if saved_composer_run_name is not None:
        os.environ['COMPOSER_RUN_NAME'] = saved_composer_run_name
    if saved_run_name is not None:
        os.environ['RUN_NAME'] = saved_run_name
tests/fixtures/fixtures.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import pytest
5
+
6
+ from tests.conftest import _get_option
7
+
8
+
9
@pytest.fixture
def rank_zero_seed(pytestconfig: pytest.Config) -> int:
    """Return the ``seed`` option as an int, defaulting to 0."""
    return int(_get_option(pytestconfig, 'seed', default='0'))
tests/layers/architectures.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from megablocks.layers.arguments import Arguments
8
+
9
+
10
class FFN(torch.nn.Module):
    """Dense two-layer feed-forward network.

    Holds two uninitialised weights: ``w1`` of shape
    ``(hidden_size, ffn_hidden_size)`` and ``w2`` of shape
    ``(ffn_hidden_size, hidden_size)``. They are float16 when ``args.fp16``
    is set and float32 otherwise.
    """

    def __init__(self, args: Arguments):
        super().__init__()
        param_dtype = torch.float16 if args.fp16 else torch.float32
        w1_storage = torch.empty(
            args.hidden_size,
            args.ffn_hidden_size,
            device=args.device,
            dtype=param_dtype,
        )
        self.w1 = torch.nn.Parameter(w1_storage)
        w2_storage = torch.empty(
            args.ffn_hidden_size,
            args.hidden_size,
            device=args.device,
            dtype=param_dtype,
        )
        self.w2 = torch.nn.Parameter(w2_storage)

    def forward(self, x):
        """Compute ``gelu_tanh(x @ w1) @ w2``."""
        hidden = torch.matmul(x, self.w1)
        activated = F.gelu(hidden, approximate='tanh')
        return torch.matmul(activated, self.w2)
36
+
37
+
38
class GLU(FFN):
    """Gated variant of ``FFN`` with an extra ``v1`` projection.

    ``v1`` has the same shape, device and dtype rule as ``w1``.
    """

    def __init__(self, args: Arguments):
        super().__init__(args)
        param_dtype = torch.float16 if args.fp16 else torch.float32
        v1_storage = torch.empty(
            args.hidden_size,
            args.ffn_hidden_size,
            device=args.device,
            dtype=param_dtype,
        )
        self.v1 = torch.nn.Parameter(v1_storage)

    def forward(self, x):
        """Compute ``(gelu_tanh(x @ w1) * (x @ v1)) @ w2``."""
        activated = F.gelu(torch.matmul(x, self.w1), approximate='tanh')
        gate = torch.matmul(x, self.v1)
        x1 = activated * gate
        return torch.matmul(x1, self.w2)
tests/layers/moe_test.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from functools import partial
5
+
6
+ import pytest
7
+ import torch
8
+
9
+ from megablocks.layers.arguments import Arguments
10
+ from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss
11
+ from megablocks.layers.router import batched_router_zloss, clear_router_zloss
12
+ from tests.layers.architectures import FFN
13
+
14
+ _FORWARD_TESTS = (
15
+ (16, 1024, 512, 1, 1),
16
+ (16, 1024, 512, 2, 1),
17
+ (16, 1024, 512, 4, 1),
18
+ (16, 1024, 512, 8, 1),
19
+ (8, 2048, 512, 1, 1),
20
+ (8, 2048, 512, 2, 1),
21
+ (8, 2048, 512, 4, 1),
22
+ (16, 1024, 512, 2, 2),
23
+ (16, 1024, 512, 4, 2),
24
+ (16, 1024, 512, 4, 4),
25
+ (16, 1024, 512, 8, 2),
26
+ (16, 1024, 512, 8, 4),
27
+ (16, 1024, 512, 8, 8),
28
+ )
29
+
30
+ _DENSE_TESTS = (
31
+ (16, 1024, 512),
32
+ (8, 2048, 512),
33
+ )
34
+
35
+
36
def _triton_version_at_least(version: str, minimum: tuple) -> bool:
    """Return True when the dotted ``version`` string is >= ``minimum``.

    Components are compared numerically. Only the leading digits of the first
    ``len(minimum)`` components are used, so suffixes such as ``+git1234`` or
    ``rc1`` are ignored.
    """
    parsed = []
    for piece in version.split('.')[:len(minimum)]:
        digits = ''
        for char in piece:
            if not char.isdigit():
                break
            digits += char
        parsed.append(int(digits) if digits else 0)
    # Pad short versions such as "3.2" so they compare equal to (3, 2, 0).
    parsed.extend([0] * (len(minimum) - len(parsed)))
    return tuple(parsed) >= minimum


def construct_moe(
    hidden_size: int,
    ffn_hidden_size: int,
    moe_num_experts: int = 1,
    moe_capacity_factor: int = 1,
    moe_top_k: int = 1,
    moe_zloss_weight: float = 0,
):
    """Build a dense ``FFN`` baseline and an ``MoE`` layer from shared arguments.

    Both modules are moved to the current CUDA device and cast to half
    precision. With a single expert, the expert weights are copied into the
    baseline so the two modules start from the same parameters.

    Args:
        hidden_size: Model hidden dimension.
        ffn_hidden_size: Hidden dimension of the feed-forward block.
        moe_num_experts: Number of experts.
        moe_capacity_factor: Capacity factor forwarded to ``Arguments``.
        moe_top_k: Number of experts each token is routed to.
        moe_zloss_weight: Router z-loss weight forwarded to ``Arguments``.

    Returns:
        A tuple ``(args, mlp, moe_mlp)``.
    """
    # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
    # TODO: Remove this once sparse is supported with triton >=3.2.0
    try:
        import triton
    except ImportError:
        triton = None
    # Compare numerically: as plain strings, '3.10.0' sorts before '3.2.0'.
    if triton is not None and _triton_version_at_least(triton.__version__, (3, 2, 0)):
        pytest.skip('Sparse MLP is not supported with triton >=3.2.0')

    init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
    args = Arguments(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        moe_num_experts=moe_num_experts,
        moe_capacity_factor=moe_capacity_factor,
        moe_top_k=moe_top_k,
        init_method=init_method,
        moe_zloss_weight=moe_zloss_weight,
    )

    mlp = FFN(args)
    moe_mlp = MoE(args)

    mlp.cuda(torch.cuda.current_device()).half()
    moe_mlp.cuda(torch.cuda.current_device()).half()

    # Set the baseline parameters to match exactly.
    if moe_num_experts == 1:
        with torch.no_grad():
            mlp.w1.copy_(moe_mlp.experts.mlp.w1.squeeze())
            mlp.w2.copy_(moe_mlp.experts.mlp.w2.squeeze())
    return args, mlp, moe_mlp
76
+
77
+
78
@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
def test_moe_forward(bs: int, sl: int, hs: int, num_experts: int, top_k: int):
    """Run a forward pass through the MoE layer and check the output shape."""
    inputs = torch.randn(sl, bs, hs).half().cuda()

    moe_kwargs = dict(
        hidden_size=hs,
        ffn_hidden_size=hs * 2,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
    )
    _args, _dense, moe_layer = construct_moe(**moe_kwargs)

    output, _ = moe_layer(inputs)
    assert output.shape == inputs.shape
    clear_load_balancing_loss()
93
+
94
+
95
@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
def test_moe_forward_backward(
    bs: int,
    sl: int,
    hs: int,
    num_experts: int,
    top_k: int,
):
    """Run forward and backward through the MoE layer with the load-balancing loss."""
    inputs = torch.randn(sl, bs, hs).half().cuda()
    inputs.requires_grad_(True)

    moe_kwargs = dict(
        hidden_size=hs,
        ffn_hidden_size=hs * 2,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
    )
    args, _dense, moe_layer = construct_moe(**moe_kwargs)

    output, _ = moe_layer(inputs)
    assert output.shape == inputs.shape

    total_loss = output.sum() + batched_load_balancing_loss(args)
    total_loss.backward()

    # Drop gradients and the accumulated loss state before the next case.
    moe_layer.zero_grad(set_to_none=True)
    inputs.grad = None
    clear_load_balancing_loss()
122
+
123
+
124
@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
def test_moe_forward_backward_with_zloss(
    bs: int,
    sl: int,
    hs: int,
    num_experts: int,
    top_k: int,
):
    """Run forward and backward through an MoE layer built with ``moe_zloss_weight=1e-3``."""
    inputs = torch.randn(sl, bs, hs).half().cuda()
    inputs.requires_grad_(True)

    moe_kwargs = dict(
        hidden_size=hs,
        ffn_hidden_size=hs * 2,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
        moe_zloss_weight=1e-3,
    )
    args, _dense, moe_layer = construct_moe(**moe_kwargs)

    output, _ = moe_layer(inputs)
    assert output.shape == inputs.shape

    total_loss = output.sum() + batched_load_balancing_loss(args)
    total_loss.backward()

    # Drop gradients and both accumulated loss states before the next case.
    moe_layer.zero_grad(set_to_none=True)
    inputs.grad = None
    clear_load_balancing_loss()
    clear_router_zloss()
153
+
154
+
155
@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS)
def test_moe_forward_vs_dense(bs: int, sl: int, hs: int):
    """Compare a single-expert MoE forward pass with the dense baseline."""
    inputs = torch.randn(sl, bs, hs).half().cuda()

    _args, dense_mlp, moe_layer = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2)

    dense_output = dense_mlp(inputs)
    moe_output, _ = moe_layer(inputs)

    assert moe_output.shape == inputs.shape == dense_output.shape
    assert torch.allclose(moe_output, dense_output)
    clear_load_balancing_loss()
167
+
168
+
169
@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS)
def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int):
    """Compare single-expert MoE weight gradients with the dense baseline."""
    inputs = torch.randn(sl, bs, hs).half().cuda()
    inputs.requires_grad_(True)

    _args, dense_mlp, moe_layer = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2)

    # MoE pass: capture the expert weight gradients, then reset state.
    moe_output, _ = moe_layer(inputs)
    moe_output.sum().backward()
    expert_mlp = moe_layer.experts.mlp
    moe_w1_grad = expert_mlp.w1.grad.detach().squeeze()
    moe_w2_grad = expert_mlp.w2.grad.detach().squeeze()
    moe_layer.zero_grad(set_to_none=True)
    inputs.grad = None
    clear_load_balancing_loss()

    # Dense pass on the same input.
    dense_output = dense_mlp(inputs)
    dense_output.sum().backward()
    dense_w1_grad = dense_mlp.w1.grad.detach()
    dense_w2_grad = dense_mlp.w2.grad.detach()
    dense_mlp.zero_grad(set_to_none=True)
    inputs.grad = None

    # Verify the gradients match.
    assert moe_w1_grad.shape == dense_w1_grad.shape
    assert moe_w2_grad.shape == dense_w2_grad.shape
    assert torch.allclose(moe_w1_grad, dense_w1_grad)
    assert torch.allclose(moe_w2_grad, dense_w2_grad)
    clear_load_balancing_loss()
tests/ops/binned_gather_test.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import pytest
6
+ import torch
7
+
8
+ from megablocks import ops
9
+
10
+ BINNED_GATHER_TESTS = (
11
+ (4, 2, 2, 1),
12
+ (4, 2, 2, 2),
13
+ (4, 2, 2, 4),
14
+ (1024, 1536, 4, 1),
15
+ (1024, 1536, 4, 2),
16
+ (1024, 1536, 4, 4),
17
+ (1024, 1536, 64, 1),
18
+ (1024, 1536, 64, 2),
19
+ (1024, 1536, 64, 4),
20
+ (1024, 1536, 128, 1),
21
+ (1024, 1536, 128, 2),
22
+ (1024, 1536, 128, 4),
23
+ (16384, 768, 4, 1),
24
+ (16384, 768, 4, 2),
25
+ (16384, 768, 4, 4),
26
+ (16384, 768, 64, 1),
27
+ (16384, 768, 64, 2),
28
+ (16384, 768, 64, 4),
29
+ (16384, 768, 128, 1),
30
+ (16384, 768, 128, 2),
31
+ (16384, 768, 128, 4),
32
+ )
33
+
34
+
35
@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), BINNED_GATHER_TESTS)
def test_binned_gather(sl: int, hs: int, ne: int, top_k: int):
    """Check ``ops.binned_gather`` against a NumPy reference implementation."""
    # NOTE: Capacity factor == 1.
    ec = (sl * top_k) // ne

    # Create the data and indices.
    x = torch.randn((sl, hs)).cuda().half()

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
    _, indices = ops.sort(top_expert)
    bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)

    def binned_gather(
        x: torch.Tensor,
        indices: torch.Tensor,
        bins: torch.Tensor,
        ec: int,
        top_k: int,
    ):
        """Reference: copy up to ``ec`` rows per expert into an (ne, ec, hs) array."""
        x_np = x.cpu().numpy()
        indices_np = indices.cpu().numpy()
        bins_np = bins.cpu().numpy()

        gathered = np.zeros((ne, ec, hs))
        bin_start = 0
        for expert in range(ne):
            bin_end = bins_np[expert]
            # Entries beyond the expert capacity are dropped.
            kept = min(ec, bin_end - bin_start)
            for slot in range(kept):
                source_row = indices_np[bin_start + slot] // top_k
                gathered[expert, slot, :] = x_np[source_row, :]
            bin_start = bin_end
        return torch.from_numpy(gathered).cuda().half()

    out = ops.binned_gather(x, indices, bins, ec, top_k)
    expected_out = binned_gather(x, indices, bins, ec, top_k)
    assert torch.all(torch.eq(out, expected_out))
tests/ops/binned_scatter_test.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import pytest
6
+ import torch
7
+
8
+ from megablocks import ops
9
+
10
+ _BINNED_SCATTER_TESTS = (
11
+ (4, 2, 2, 1),
12
+ (4, 2, 2, 2),
13
+ (4, 2, 2, 4),
14
+ (1024, 1536, 4, 1),
15
+ (1024, 1536, 4, 2),
16
+ (1024, 1536, 4, 4),
17
+ (1024, 1536, 64, 1),
18
+ (1024, 1536, 64, 2),
19
+ (1024, 1536, 64, 4),
20
+ (1024, 1536, 128, 1),
21
+ (1024, 1536, 128, 2),
22
+ (1024, 1536, 128, 4),
23
+ (16384, 768, 4, 1),
24
+ (16384, 768, 4, 2),
25
+ (16384, 768, 4, 4),
26
+ (16384, 768, 64, 1),
27
+ (16384, 768, 64, 2),
28
+ (16384, 768, 64, 4),
29
+ (16384, 768, 128, 1),
30
+ (16384, 768, 128, 2),
31
+ (16384, 768, 128, 4),
32
+ )
33
+
34
+
35
@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), _BINNED_SCATTER_TESTS)
def testBinnedScatter(sl: int, hs: int, ne: int, top_k: int):
    """Check ``ops.binned_scatter`` against a NumPy reference implementation.

    The input to the scatter is produced by ``ops.binned_gather``; the result
    is compared to the reference with a relative tolerance.
    """
    # NOTE: Capacity factor == 1.
    ec = (sl * top_k) // ne

    # Create the data and indices.
    x = torch.randn((sl, hs)).cuda().half()

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
    _, indices = ops.sort(top_expert)
    bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)

    # Sample weights for the scatter reduce.
    weights = torch.rand((sl * top_k,)).cuda().half()

    x = ops.binned_gather(x, indices, bins, ec, top_k)

    def binned_scatter(
        x: torch.Tensor,
        indices: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        top_k: int,
    ):
        """Reference: accumulate weighted expert rows back into an (sl, hs) array."""
        x = x.cpu().numpy()
        indices = indices.cpu().numpy()
        weights = weights.cpu().numpy()
        bins = bins.cpu().numpy()
        start = 0
        out = np.zeros((sl, hs))
        for i in range(ne):
            end = bins[i]
            for j in range(min(ec, end - start)):
                index = indices[start + j]
                # The weight is looked up before the index is divided by top_k.
                scale = weights[index]
                index //= top_k

                out[index, :] += scale * x[i, j, :]
            start = end
        return torch.from_numpy(out).cuda().half()

    out = ops.binned_scatter(x, indices, weights, bins, top_k)
    expected_out = binned_scatter(x, indices, weights, bins, top_k)

    # NOTE: We need to check approximate equality because the
    # scatter reduce uses atomics.
    # assert_allclose raises on mismatch and returns None, so it is called
    # directly; wrapping it in ``assert ... is None`` would strip the whole
    # comparison under ``python -O``.
    np.testing.assert_allclose(
        out.cpu(),
        expected_out.cpu(),
        rtol=5e-3,
    )
tests/ops/cumsum_test.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import pytest
5
+ import torch
6
+
7
+ from megablocks import ops
8
+
9
+ CUMSUM_TESTS = (
10
+ (1, 32),
11
+ (2, 32),
12
+ (2, 1024),
13
+ (4, 1024),
14
+ (8, 1024),
15
+ (16, 1024),
16
+ (32, 1024),
17
+ (64, 1024),
18
+ (128, 1024),
19
+ (2, 16384),
20
+ (4, 16384),
21
+ (8, 16384),
22
+ (16, 16384),
23
+ (32, 16384),
24
+ (64, 16384),
25
+ (128, 16384),
26
+ )
27
+
28
+
29
@pytest.mark.gpu
@pytest.mark.parametrize(('n', 'm'), CUMSUM_TESTS)
def test_exclusive_cumsum(n: int, m: int):
    """Compare ``ops.exclusive_cumsum`` along dim 1 with ``torch.cumsum - 1``.

    Both sides are multiplied by the 0/1 input, so only positions where the
    input is 1 take part in the comparison.
    """
    mask = torch.randint(0, 2, (n, m)).long().cuda()
    actual = ops.exclusive_cumsum(mask, 1) * mask
    reference = (torch.cumsum(mask, dim=1) - 1) * mask
    assert torch.all(torch.eq(actual, reference))
36
+
37
+
38
@pytest.mark.gpu
@pytest.mark.parametrize(('n', 'm'), CUMSUM_TESTS)
def test_inclusive_cumsum(n: int, m: int):
    """Compare ``ops.inclusive_cumsum`` along dim 1 with ``torch.cumsum``."""
    values = torch.randint(0, 2, (n, m)).long().cuda()
    actual = ops.inclusive_cumsum(values, 1)
    reference = torch.cumsum(values, dim=1)
    assert torch.all(torch.eq(actual, reference))
tests/ops/histogram_test.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import pytest
5
+ import torch
6
+
7
+ from megablocks import ops
8
+
9
+ _HISTOGRAM_TESTS = (
10
+ (1, 32, torch.int16, 128),
11
+ (1, 1024, torch.int16, 128),
12
+ (1, 16384, torch.int16, 128),
13
+ (1, 32, torch.int32, 128),
14
+ (1, 1024, torch.int32, 128),
15
+ (1, 16384, torch.int32, 128),
16
+ (1, 32, torch.int64, 128),
17
+ (1, 1024, torch.int64, 128),
18
+ (1, 16384, torch.int64, 128),
19
+ (1, 32, torch.int16, 1024),
20
+ (1, 1024, torch.int16, 1024),
21
+ (1, 16384, torch.int16, 1024),
22
+ (1, 32, torch.int32, 1024),
23
+ (1, 1024, torch.int32, 1024),
24
+ (1, 16384, torch.int32, 1024),
25
+ (1, 32, torch.int64, 1024),
26
+ (1, 1024, torch.int64, 1024),
27
+ (1, 16384, torch.int64, 1024),
28
+ (2, 32, torch.int16, 128),
29
+ (2, 1024, torch.int16, 128),
30
+ (2, 16384, torch.int16, 128),
31
+ (2, 32, torch.int32, 128),
32
+ (2, 1024, torch.int32, 128),
33
+ (2, 16384, torch.int32, 128),
34
+ (2, 32, torch.int64, 128),
35
+ (2, 1024, torch.int64, 128),
36
+ (2, 16384, torch.int64, 128),
37
+ (2, 32, torch.int16, 1024),
38
+ (2, 1024, torch.int16, 1024),
39
+ (2, 16384, torch.int16, 1024),
40
+ (2, 32, torch.int32, 1024),
41
+ (2, 1024, torch.int32, 1024),
42
+ (2, 16384, torch.int32, 1024),
43
+ (2, 32, torch.int64, 1024),
44
+ (2, 1024, torch.int64, 1024),
45
+ (2, 16384, torch.int64, 1024),
46
+ (8, 32, torch.int16, 128),
47
+ (8, 1024, torch.int16, 128),
48
+ (8, 16384, torch.int16, 128),
49
+ (8, 32, torch.int32, 128),
50
+ (8, 1024, torch.int32, 128),
51
+ (8, 16384, torch.int32, 128),
52
+ (8, 32, torch.int64, 128),
53
+ (8, 1024, torch.int64, 128),
54
+ (8, 16384, torch.int64, 128),
55
+ (8, 32, torch.int16, 1024),
56
+ (8, 1024, torch.int16, 1024),
57
+ (8, 16384, torch.int16, 1024),
58
+ (8, 32, torch.int32, 1024),
59
+ (8, 1024, torch.int32, 1024),
60
+ (8, 16384, torch.int32, 1024),
61
+ (8, 32, torch.int64, 1024),
62
+ (8, 1024, torch.int64, 1024),
63
+ (8, 16384, torch.int64, 1024),
64
+ )
65
+
66
+
67
+ # Override the seed_all fixture in autouse.py because
68
+ # _histc_cuda does not have a deterministic implementation
69
@pytest.fixture()
def seed_all():
    """Module-local ``seed_all`` fixture that turns off deterministic algorithms.

    Needed because ``_histc_cuda`` has no deterministic implementation.
    """
    torch.use_deterministic_algorithms(False)
73
+
74
+
75
@pytest.mark.gpu
@pytest.mark.parametrize(('m', 'n', 'dtype', 'max_val'), _HISTOGRAM_TESTS)
def test_histogram(m: int, n: int, dtype: torch.dtype, max_val: int):
    """Compare ``ops.histogram`` with a per-row ``torch.histc`` reference."""
    x = torch.randint(0, max_val, (m, n)).cuda().to(dtype)

    actual = ops.histogram(x, max_val)

    # One histc call per row, with max_val bins spanning [0, max_val - 1].
    per_row = [torch.histc(row, max_val, 0, max_val - 1) for row in torch.split(x, 1)]
    reference = torch.stack(per_row)
    assert torch.all(torch.eq(actual, reference))
tests/ops/padded_gather_test.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import pytest
6
+ import torch
7
+
8
+ from megablocks import ops
9
+
10
+ PADDED_GATHER_TESTS = (
11
+ (4, 2, 2, 1),
12
+ (4, 2, 2, 2),
13
+ (1024, 1, 4, 1),
14
+ (1024, 1, 4, 2),
15
+ (1024, 1, 4, 4),
16
+ (1024, 1, 64, 1),
17
+ (1024, 1, 64, 2),
18
+ (1024, 1, 64, 4),
19
+ (1024, 1, 128, 1),
20
+ (1024, 1, 128, 2),
21
+ (1024, 1, 128, 4),
22
+ (1024, 1536, 4, 1),
23
+ (1024, 1536, 4, 2),
24
+ (1024, 1536, 4, 4),
25
+ (1024, 1536, 64, 1),
26
+ (1024, 1536, 64, 2),
27
+ (1024, 1536, 64, 4),
28
+ (1024, 1536, 128, 1),
29
+ (1024, 1536, 128, 2),
30
+ (1024, 1536, 128, 4),
31
+ (16384, 768, 4, 1),
32
+ (16384, 768, 4, 2),
33
+ (16384, 768, 4, 4),
34
+ (16384, 768, 64, 1),
35
+ (16384, 768, 64, 2),
36
+ (16384, 768, 64, 4),
37
+ (16384, 768, 128, 1),
38
+ (16384, 768, 128, 2),
39
+ (16384, 768, 128, 4),
40
+ (16384, 1, 4, 1),
41
+ (16384, 1, 4, 2),
42
+ (16384, 1, 4, 4),
43
+ (16384, 1, 64, 1),
44
+ (16384, 1, 64, 2),
45
+ (16384, 1, 64, 4),
46
+ (16384, 1, 128, 1),
47
+ (16384, 1, 128, 2),
48
+ (16384, 1, 128, 4),
49
+ )
50
+
51
+
52
@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), PADDED_GATHER_TESTS)
def testPaddedGather(sl: int, hs: int, ne: int, top_k: int):
    """Check ``ops.padded_gather`` against a NumPy reference implementation."""
    # Create the data and indices.
    x = torch.randn((sl, hs)).cuda().half()

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
    bin_ids, indices = ops.sort(top_expert)
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
    bins = ops.inclusive_cumsum(tokens_per_expert, 0)

    def padded_gather(
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        bins: torch.Tensor,
        padded_bins: torch.Tensor,
        top_k: int,
    ):
        """Reference: copy each bin's rows to the start of its padded bin."""
        x_np = x.cpu().numpy()
        indices_np = indices.cpu().numpy()
        bin_ids = bin_ids.cpu().numpy()
        bins_np = bins.cpu().numpy()
        padded_bins_np = padded_bins.cpu().numpy()

        gathered = np.zeros((padded_bins_np[-1], hs))
        read_pos = 0
        for expert, bin_end in enumerate(bins_np):
            # Each bin's output starts where the previous padded bin ended.
            write_pos = 0 if expert == 0 else padded_bins_np[expert - 1]
            while read_pos < bin_end:
                source_row = indices_np[read_pos] // top_k
                gathered[write_pos, :] = x_np[source_row, :]
                read_pos += 1
                write_pos += 1
        return torch.from_numpy(gathered).cuda().half()

    out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
    expected_out = padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
    assert torch.all(torch.eq(out, expected_out))
tests/ops/padded_scatter_test.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import pytest
6
+ import torch
7
+
8
+ from megablocks import ops
9
+
10
+ PADDED_SCATTER_TESTS = [
11
+ (4, 2, 2, 2),
12
+ (4, 2, 2, 1),
13
+ (4, 2, 2, 1),
14
+ (4, 2, 2, 1),
15
+ (4, 2, 2, 2),
16
+ (4, 2, 2, 2),
17
+ (1024, 1, 4, 1),
18
+ (1024, 1, 4, 2),
19
+ (1024, 1, 4, 4),
20
+ (1024, 1, 4, 1),
21
+ (1024, 1, 4, 2),
22
+ (1024, 1, 4, 4),
23
+ (1024, 1, 4, 1),
24
+ (1024, 1, 4, 2),
25
+ (1024, 1, 4, 4),
26
+ (1024, 1, 64, 1),
27
+ (1024, 1, 64, 2),
28
+ (1024, 1, 64, 4),
29
+ (1024, 1, 128, 1),
30
+ (1024, 1, 128, 2),
31
+ (1024, 1, 128, 4),
32
+ (1024, 1536, 4, 1),
33
+ (1024, 1536, 4, 2),
34
+ (1024, 1536, 4, 4),
35
+ (1024, 1536, 4, 4),
36
+ (1024, 1536, 4, 4),
37
+ (1024, 1536, 64, 1),
38
+ (1024, 1536, 64, 2),
39
+ (1024, 1536, 64, 4),
40
+ (1024, 1536, 128, 1),
41
+ (1024, 1536, 128, 2),
42
+ (1024, 1536, 128, 4),
43
+ (1024, 1536, 128, 1),
44
+ (1024, 1536, 128, 1),
45
+ (16384, 768, 4, 1),
46
+ (16384, 768, 4, 2),
47
+ (16384, 768, 4, 4),
48
+ (16384, 768, 64, 1),
49
+ (16384, 768, 64, 2),
50
+ (16384, 768, 64, 4),
51
+ (16384, 768, 128, 1),
52
+ (16384, 768, 128, 2),
53
+ (16384, 768, 128, 4),
54
+ (16384, 1, 4, 1),
55
+ (16384, 1, 4, 2),
56
+ (16384, 1, 4, 4),
57
+ (16384, 1, 64, 1),
58
+ (16384, 1, 64, 2),
59
+ (16384, 1, 64, 4),
60
+ (16384, 1, 128, 1),
61
+ (16384, 1, 128, 2),
62
+ (16384, 1, 128, 4),
63
+ (16384, 1, 128, 2),
64
+ (16384, 1, 128, 2),
65
+ ]
66
+
67
+
68
def _to_numpy(x: torch.Tensor) -> np.ndarray:
    """Return ``x`` as a NumPy array after detaching it and moving it to the CPU."""
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()
70
+
71
+
72
@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), PADDED_SCATTER_TESTS)
def testPaddedScatter(sl: int, hs: int, ne: int, top_k: int):
    """Compare ``ops.padded_scatter`` with a host-side NumPy reference.

    The input is first run through ``ops.padded_gather``; the scatter result is
    then checked approximately and a backward pass is run as a sanity check.

    Args:
        sl: Number of rows in the random input.
        hs: Width of each input row.
        ne: Exclusive upper bound for the randomly drawn expert ids.
        top_k: Number of expert slots drawn per input row.
    """

    def reference_scatter(
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        padded_bins: torch.Tensor,
        top_k: int,
    ):
        rows = x.detach().cpu().numpy()
        order: np.ndarray = _to_numpy(indices)
        bin_ids: np.ndarray = _to_numpy(bin_ids)  # converted for parity; not consulted below
        scales: np.ndarray = _to_numpy(weights)
        ends: np.ndarray = _to_numpy(bins)
        padded_ends: np.ndarray = _to_numpy(padded_bins)

        reduced = np.zeros((order.shape[0] // top_k, hs))
        pos = 0
        for expert, stop in enumerate(ends):
            # Each expert reads starting at the end of the previous padded segment.
            src = padded_ends[expert - 1] if expert != 0 else 0
            while pos < stop:
                slot = order[pos]
                # The weight is looked up per slot; the output row is slot // top_k.
                reduced[slot // top_k, :] += scales[slot] * rows[src, :]
                pos += 1
                src += 1
        return torch.from_numpy(reduced).cuda().half()

    # Create the data and indices.
    x = torch.randn((sl, hs), requires_grad=True).cuda().half()

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
    bin_ids, indices = ops.sort(top_expert)
    tokens_per_expert = ops.histogram(top_expert, ne)
    bins = ops.inclusive_cumsum(tokens_per_expert, 0)
    padded_bins = ops.inclusive_cumsum(
        ops.round_up(tokens_per_expert, 128),
        0,
    )

    # Sample weights for the scatter reduce.
    weights = torch.rand((sl * top_k,), requires_grad=True).cuda().half()

    # Gather the data to prepare for backwards.
    x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)

    scatter_args = (x, indices, bin_ids, weights, bins, padded_bins, top_k)
    out = ops.padded_scatter(*scatter_args)
    expected_out = reference_scatter(*scatter_args)

    out.backward(torch.randn_like(out))  # sanity check backward pass

    # NOTE: We need to check approximate equality because the scatter reduce uses atomics.
    # np.testing.assert_allclose raises an AssertionError by itself on mismatch.
    np.testing.assert_allclose(
        _to_numpy(out),
        _to_numpy(expected_out),
        rtol=5e-3,
    )
tests/ops/replicate_test.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import pytest
6
+ import torch
7
+
8
+ try:
9
+ from megablocks._ops import ops as backend # type: ignore
10
+ except ModuleNotFoundError as e:
11
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
12
+
13
+ from megablocks import ops
14
+
15
+
16
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
    """Return a zero-dimensional tensor as a one-element 1-D view; pass others through."""
    if x.dim() == 0:
        return x.view(1)
    return x
18
+
19
+
20
+ REPLICATE_TESTS = [
21
+ (8, 1, 1),
22
+ (8, 2, 1),
23
+ (8, 4, 1),
24
+ (8, 8, 1),
25
+ (8, 2, 2),
26
+ (8, 4, 2),
27
+ (8, 8, 2),
28
+ (8, 2, 4),
29
+ (8, 4, 4),
30
+ (8, 8, 4),
31
+ (8, 2, 8),
32
+ (8, 4, 8),
33
+ (8, 8, 8),
34
+ (16384, 2, 1),
35
+ (16384, 4, 1),
36
+ (16384, 8, 1),
37
+ (16384, 16, 1),
38
+ (16384, 32, 1),
39
+ (16384, 64, 1),
40
+ (16384, 128, 1),
41
+ (16384, 2, 2),
42
+ (16384, 4, 2),
43
+ (16384, 8, 2),
44
+ (16384, 16, 2),
45
+ (16384, 32, 2),
46
+ (16384, 64, 2),
47
+ (16384, 128, 2),
48
+ (16384, 2, 4),
49
+ (16384, 4, 4),
50
+ (16384, 8, 4),
51
+ (16384, 16, 4),
52
+ (16384, 32, 4),
53
+ (16384, 64, 4),
54
+ (16384, 128, 4),
55
+ (16384, 2, 8),
56
+ (16384, 4, 8),
57
+ (16384, 8, 8),
58
+ (16384, 16, 8),
59
+ (16384, 32, 8),
60
+ (16384, 64, 8),
61
+ (16384, 128, 8),
62
+ ]
63
+
64
+
65
@pytest.mark.gpu
@pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
def test_replicate(tokens: int, num_centers: int, top_k: int):
    """Compare ``ops.replicate`` with a host-side NumPy reference.

    Args:
        tokens: Number of randomly assigned tokens, and the output width.
        num_centers: Number of centers (columns of the weight matrix).
        top_k: Number of rows in the weight matrix.
    """

    def reference_replicate(x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
        weights = x.cpu().numpy()
        ends = bins.cpu().numpy()
        replicated = np.zeros((weights.shape[0], num_outputs))
        for row in range(weights.shape[0]):
            col = 0
            for center, stop in enumerate(ends):
                # Columns [col, stop) of this row all take the weight of `center`.
                fill = weights[row, center]
                while col < stop:
                    replicated[row, col] = fill
                    col += 1
        return torch.from_numpy(replicated).cuda().half()

    assignment = torch.randint(0, num_centers, (tokens,)).cuda().int()
    tokens_per_center = ops.histogram(assignment, num_centers)
    # inclusive_cumsum may yield a 0-d tensor; promote_scalar makes it indexable.
    bins = promote_scalar(ops.inclusive_cumsum(tokens_per_center, 0))
    center_weights = torch.randn(top_k, num_centers).cuda().half()

    out = ops.replicate(center_weights, bins, tokens)
    expected_out = reference_replicate(center_weights, bins, tokens)
    assert torch.all(torch.eq(out, expected_out))
90
+
91
+
92
@pytest.mark.gpu
@pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
def test_replicate_backward(tokens: int, num_centers: int, top_k: int):
    """Check ``backend.replicate_backward`` on the output of ``ops.replicate``.

    The replicated weights are fed back in as the gradient, so the expected
    result is each weight multiplied by the number of tokens in its center.

    Args:
        tokens: Number of randomly assigned tokens.
        num_centers: Number of centers (columns of the weight matrix).
        top_k: Number of rows in the weight matrix.
    """
    assignment = torch.randint(0, num_centers, (tokens,)).cuda().int()
    tokens_per_center = ops.histogram(assignment, num_centers)
    bins = promote_scalar(ops.inclusive_cumsum(tokens_per_center, 0))
    center_weights = torch.randn(top_k, num_centers).cuda().half()

    grad = ops.replicate(center_weights, bins, tokens)

    # The backend op writes its result into the preallocated `out` buffer.
    out = torch.empty_like(center_weights)
    backend.replicate_backward(grad, bins, out)

    counts_row = tokens_per_center.view([1, num_centers])
    expected_out = center_weights * counts_row

    # NOTE: This floating-point reduction could be a problem for training stability and accuracy.
    assert torch.allclose(out, expected_out, rtol=1e-2)
tests/ops/sort_test.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Dict, Optional, Union
5
+
6
+ import numpy as np
7
+ import pytest
8
+ import torch
9
+
10
+ from megablocks import ops
11
+
12
+ SORT_TESTS = [
13
+ (32, torch.int16, None),
14
+ (1024, torch.int16, None),
15
+ (16384, torch.int16, None),
16
+ (32, torch.int32, None),
17
+ (1024, torch.int32, None),
18
+ (16384, torch.int32, None),
19
+ (32, torch.int64, None),
20
+ (1024, torch.int64, None),
21
+ (16384, torch.int64, None),
22
+ (32, torch.int16, 128),
23
+ (1024, torch.int16, 128),
24
+ (16384, torch.int16, 128),
25
+ (32, torch.int32, 128),
26
+ (1024, torch.int32, 128),
27
+ (16384, torch.int32, 128),
28
+ (32, torch.int64, 128),
29
+ (1024, torch.int64, 128),
30
+ (16384, torch.int64, 128),
31
+ ]
32
+
33
+
34
def torch_to_numpy_dtype(dtype: torch.dtype) -> type:
    """Return the NumPy integer scalar class matching a torch integer dtype.

    Args:
        dtype: One of ``torch.int16``, ``torch.int32`` or ``torch.int64``.

    Returns:
        The NumPy class itself (``np.int16``, ``np.int32`` or ``np.int64``),
        not an instance of it.

    Raises:
        KeyError: If ``dtype`` is not one of the three supported dtypes.
    """
    # The values are classes, hence `type` rather than a Union of scalar types.
    types: Dict[torch.dtype, type] = {
        torch.int16: np.int16,
        torch.int32: np.int32,
        torch.int64: np.int64,
    }
    return types[dtype]
41
+
42
+
43
@pytest.mark.gpu
@pytest.mark.parametrize(('n', 'dtype', 'max_val'), SORT_TESTS)
def test_sort(n: int, dtype: torch.dtype, max_val: Optional[int]):
    """Compare ``ops.sort`` with ``torch.sort`` on random integer data.

    Args:
        n: Number of elements to sort.
        dtype: Integer dtype of the input tensor.
        max_val: Exclusive upper bound of the random values; ``None`` selects
            the largest value representable in ``dtype``.
    """
    upper = max_val
    if upper is None:
        upper = np.iinfo(torch_to_numpy_dtype(dtype)).max
    # Number of bits passed to ops.sort as its second argument.
    end_bit = int(np.ceil(np.log2(upper)))
    x = torch.randint(0, upper, (n,)).cuda().to(dtype)

    out, indices = ops.sort(x, end_bit)
    expected_out, expected_indices = torch.sort(x)
    assert torch.all(torch.eq(out, expected_out))

    # NOTE: The indices can be in different order depending
    # on sort stability if multiple values in the array are
    # equal. Scattering the sorted values back through each
    # permutation compares them without relying on tie order.
    data = torch.empty_like(x)
    expected_data = torch.empty_like(x)
    data.scatter_(0, indices.long(), out)
    expected_data.scatter_(0, expected_indices, expected_out)
    assert torch.all(torch.eq(data, expected_data))
tests/ops/topology_test.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import pytest
6
+ import torch
7
+
8
+ from megablocks import ops
9
+
10
+ TOPOLOGY_TESTS = (
11
+ (1024, 1536, 2),
12
+ (1024, 1536, 4),
13
+ (1024, 1536, 8),
14
+ (1024, 1536, 16),
15
+ (1024, 1536, 32),
16
+ (1024, 1536, 64),
17
+ (1024, 1536, 128),
18
+ (1024, 1536, 256),
19
+ (1024, 1536, 512),
20
+ (16384, 768, 2),
21
+ (16384, 768, 4),
22
+ (16384, 768, 8),
23
+ (16384, 768, 16),
24
+ (16384, 768, 32),
25
+ (16384, 768, 64),
26
+ (16384, 768, 128),
27
+ (16384, 768, 256),
28
+ (16384, 768, 512),
29
+ (16384, 768, 1024),
30
+ (8, 14336, 8),
31
+ )
32
+
33
+
34
@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne'), TOPOLOGY_TESTS)
def test_topology(sl: int, hs: int, ne: int):
    """Compare ``ops.topology`` with a host-side NumPy reference.

    Args:
        sl: Number of tokens randomly assigned to experts.
        hs: Hidden size; must be a multiple of the 128-wide block.
        ne: Exclusive upper bound for the randomly drawn expert ids.
    """
    # Create the data and indices.
    blocking = 128
    assert hs % blocking == 0

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl,)).cuda().int()
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, blocking)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)

    # Dimensions for the output indices.
    output_block_rows = int(padded_bins[-1]) // blocking
    output_block_columns = hs // blocking

    def topology(
        padded_bins: torch.Tensor,
        blocking: int,
        rows: int,
        columns: int,
    ):
        """Build the expected flat index array of length ``rows * columns``."""
        padded_bins = padded_bins.cpu().numpy()

        out = np.zeros([rows * columns])
        start = 0
        for i in range(padded_bins.shape[0]):
            # Last block row (exclusive) covered by bin i.
            end = padded_bins[i] // blocking
            while start < end:
                # Every block row of bin i holds the indices i*columns .. i*columns + columns - 1.
                for j in range(columns):
                    out[start * columns + j] = j + i * columns
                start += 1
        return torch.from_numpy(out).cuda().short()

    out = ops.topology(
        padded_bins,
        blocking,
        output_block_rows,
        output_block_columns,
    )
    expected_out = topology(
        padded_bins,
        blocking,
        output_block_rows,
        output_block_columns,
    )
    assert torch.all(torch.eq(out, expected_out))
tests/test_mb_moe.py CHANGED
@@ -1,6 +1,48 @@
 
1
  import megablocks
2
 
3
  def test_import():
4
  """Simple test to check if the module can be imported."""
5
  print("megablocks_moe module imported successfully.")
6
  print("Available functions:", dir(megablocks))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
  import megablocks
3
 
4
def test_import():
    """Check that ``megablocks`` imports and exposes the expected top-level names."""
    print("megablocks_moe module imported successfully.")
    print("Available functions:", dir(megablocks))

    expected_functions = (
        "Arguments",
        "MLP",
        "MoE",
        "ParallelDroplessMLP",
        "ParallelMLP",
        "SparseGLU",
        "SparseMLP",
        "argsort",
        "backend",
        "cumsum",
        "dMoE",
        "exclusive_cumsum",
        "get_load_balancing_loss",
        "grouped_gemm_util",
        "histogram",
        "inclusive_cumsum",
        "indices",
        "layers",
        "ops",
        "replicate_backward",
        "replicate_forward",
        "sort",
        "torch",
    )

    # Check if all expected functions are available
    for name in expected_functions:
        assert name in dir(megablocks), f"Missing function: {name}"
21
+
22
+ # exclusive_cumsum
23
def test_exclusive_cumsum():
    """Test exclusive cumulative sum written into a preallocated ``out`` tensor."""
    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
    out = torch.empty_like(x)
    megablocks.exclusive_cumsum(x, 0, out)
    # Build the reference with the dtype and device of `out` so the check does
    # not depend on how torch.equal treats tensors of different dtypes.
    expected = torch.tensor([0, 1, 3, 6], dtype=out.dtype, device=out.device)
    assert torch.equal(out, expected), f"Expected {expected}, got {out}"
    print("cumsum output:", out)
31
+
32
+ # inclusive_cumsum
33
def test_inclusive_cumsum():
    """Test inclusive cumulative sum written into a preallocated ``out`` tensor."""
    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
    out = torch.empty_like(x)
    megablocks.inclusive_cumsum(x, dim=0, out=out)
    # Build the reference with the dtype and device of `out` so the check does
    # not depend on how torch.equal treats tensors of different dtypes.
    expected = torch.tensor([1, 3, 6, 10], dtype=out.dtype, device=out.device)
    assert torch.equal(out, expected), f"Expected {expected}, got {out}"
40
+
41
+ # histogram
42
def test_histogram():
    """Check ``megablocks.histogram`` on a small hand-counted input."""
    num_bins = 3
    # One 0, two 1s and three 2s.
    samples = torch.tensor([0, 1, 1, 2, 2, 2], dtype=torch.int16).cuda()
    expected_hist = torch.tensor([1, 2, 3], dtype=torch.int32).cuda()

    hist = megablocks.histogram(samples, num_bins)
    assert torch.equal(hist, expected_hist), f"Expected {expected_hist}, got {hist}"
torch-ext/megablocks/__init__.py CHANGED
@@ -24,7 +24,9 @@ def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tens
24
  Returns:
25
  The output tensor
26
  """
27
- return ops.exclusive_cumsum(x, dim, out)
 
 
28
 
29
 
30
  def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
@@ -39,7 +41,9 @@ def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tens
39
  Returns:
40
  The output tensor
41
  """
42
- return ops.inclusive_cumsum(x, dim, out)
 
 
43
 
44
 
45
  def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
 
24
  Returns:
25
  The output tensor
26
  """
27
+ result = ops.exclusive_cumsum(x, dim)
28
+ out.copy_(result)
29
+ return out
30
 
31
 
32
  def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
 
41
  Returns:
42
  The output tensor
43
  """
44
+ result = ops.inclusive_cumsum(x, dim)
45
+ out.copy_(result)
46
+ return out
47
 
48
 
49
  def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
torch-ext/megablocks/ops/cumsum.py CHANGED
@@ -11,7 +11,7 @@ import torch
11
  # instructions for building the c++ operations.
12
  try:
13
  # import megablocks_ops as ops # type: ignore
14
- import megablocks._ops as ops # type: ignore
15
  except ModuleNotFoundError as e:
16
  raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
17
 
 
11
  # instructions for building the c++ operations.
12
  try:
13
  # import megablocks_ops as ops # type: ignore
14
+ from megablocks._ops import ops # type: ignore
15
  except ModuleNotFoundError as e:
16
  raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
17
 
torch-ext/megablocks/ops/histogram.py CHANGED
@@ -10,7 +10,7 @@ import torch
10
  # Wrap this in a try-block with better error message and
11
  # instructions for building the c++ operations.
12
  try:
13
- import megablocks._ops as ops # type: ignore
14
  except ModuleNotFoundError as e:
15
  raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
16
 
 
10
  # Wrap this in a try-block with better error message and
11
  # instructions for building the c++ operations.
12
  try:
13
+ from megablocks._ops import ops # type: ignore
14
  except ModuleNotFoundError as e:
15
  raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
16
 
torch-ext/megablocks/ops/replicate.py CHANGED
@@ -10,8 +10,7 @@ import torch
10
  # Wrap this in a try-block with better error message and
11
  # instructions for building the c++ operations.
12
  try:
13
- # import megablocks_ops as ops # type: ignore
14
- import megablocks._ops as ops # type: ignore
15
  except ModuleNotFoundError as e:
16
  raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
17
 
 
10
  # Wrap this in a try-block with better error message and
11
  # instructions for building the c++ operations.
12
  try:
13
+ from megablocks._ops import ops # type: ignore
 
14
  except ModuleNotFoundError as e:
15
  raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
16
 
torch-ext/megablocks/ops/sort.py CHANGED
@@ -10,8 +10,7 @@ import torch
10
  # Wrap this in a try-block with better error message and
11
  # instructions for building the c++ operations.
12
  try:
13
- # import megablocks_ops as ops # type: ignore
14
- import megablocks._ops as ops # type: ignore
15
  except ModuleNotFoundError as e:
16
  raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
17
 
 
10
  # Wrap this in a try-block with better error message and
11
  # instructions for building the c++ operations.
12
  try:
13
+ from megablocks._ops import ops # type: ignore
 
14
  except ModuleNotFoundError as e:
15
  raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
16
 
torch-ext/megablocks/ops/topology.py CHANGED
@@ -10,8 +10,7 @@ import torch
10
  # Wrap this in a try-block with better error message and
11
  # instructions for building the c++ operations.
12
  try:
13
- # import megablocks_ops as ops # type: ignore
14
- import megablocks._ops as ops # type: ignore
15
  except ModuleNotFoundError as e:
16
  raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
17
 
 
10
  # Wrap this in a try-block with better error message and
11
  # instructions for building the c++ operations.
12
  try:
13
+ from megablocks._ops import ops # type: ignore
 
14
  except ModuleNotFoundError as e:
15
  raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
16
 
torch-ext/torch_binding.cpp CHANGED
@@ -34,22 +34,22 @@ torch::Tensor histogram_wrapper(torch::Tensor x, int64_t num_bins) {
34
  torch::Tensor indices_wrapper(torch::Tensor padded_bins,
35
  int64_t block_size,
36
  int64_t output_block_rows,
37
- int64_t output_block_columns) {
38
- torch::Tensor out = torch::empty({output_block_rows * output_block_columns}, torch::kInt16);
39
  megablocks::indices(padded_bins, block_size, output_block_rows, output_block_columns, out);
40
  return out;
41
  }
42
 
43
 
44
 
45
- // // // Forward pass: replicate values from x according to bin sizes
46
- // // void replicate_forward(torch::Tensor x,
47
- // // torch::Tensor bins,
48
- // // torch::Tensor out);
49
- // tensor::Tensor replicate_forward_wrapper(torch::Tensor x, torch::Tensor bins, torch::Tensor out) {
50
- // megablocks::replicate_forward(x, bins, out);
51
- // return out;
52
- // }
53
 
54
  // // Backward pass: reduce gradients back to bins using segmented reduction
55
  // void replicate_backward(torch::Tensor grad,
@@ -90,11 +90,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
90
  ops.def("histogram(Tensor x, int num_bins) -> Tensor");
91
  ops.impl("histogram", torch::kCUDA, &histogram_wrapper);
92
 
93
- ops.def("indices(Tensor padded_bins, int block_size, int output_block_rows, int output_block_columns) -> Tensor");
94
  ops.impl("indices", torch::kCUDA, &indices_wrapper);
95
 
96
- // ops.def("replicate_forward(Tensor x, Tensor bins, Tensor(a!) out) -> Tensor(a!)");
97
- // ops.impl("replicate_forward", torch::kCUDA, &replicate_forward_wrapper);
98
 
99
  ops.def("replicate_backward(Tensor grad, Tensor bins, Tensor(a!) out) -> Tensor(a!)");
100
  ops.impl("replicate_backward", torch::kCUDA, &replicate_backward_wrapper);
 
34
  torch::Tensor indices_wrapper(torch::Tensor padded_bins,
35
  int64_t block_size,
36
  int64_t output_block_rows,
37
+ int64_t output_block_columns,
38
+ torch::Tensor out) {
39
  megablocks::indices(padded_bins, block_size, output_block_rows, output_block_columns, out);
40
  return out;
41
  }
42
 
43
 
44
 
45
+ // Forward pass: replicate values from x according to bin sizes
46
+ // void replicate_forward(torch::Tensor x,
47
+ // torch::Tensor bins,
48
+ // torch::Tensor out);
49
+ torch::Tensor replicate_forward_wrapper(torch::Tensor x, torch::Tensor bins, torch::Tensor out) {
50
+ megablocks::replicate_forward(x, bins, out);
51
+ return out;
52
+ }
53
 
54
  // // Backward pass: reduce gradients back to bins using segmented reduction
55
  // void replicate_backward(torch::Tensor grad,
 
90
  ops.def("histogram(Tensor x, int num_bins) -> Tensor");
91
  ops.impl("histogram", torch::kCUDA, &histogram_wrapper);
92
 
93
+ ops.def("indices(Tensor padded_bins, int block_size, int output_block_rows, int output_block_columns, Tensor(a!) out) -> Tensor(a!)");
94
  ops.impl("indices", torch::kCUDA, &indices_wrapper);
95
 
96
+ ops.def("replicate_forward(Tensor x, Tensor bins, Tensor(a!) out) -> Tensor(a!)");
97
+ ops.impl("replicate_forward", torch::kCUDA, &replicate_forward_wrapper);
98
 
99
  ops.def("replicate_backward(Tensor grad, Tensor bins, Tensor(a!) out) -> Tensor(a!)");
100
  ops.impl("replicate_backward", torch::kCUDA, &replicate_backward_wrapper);