drbh committed · Commit 9c4ca75 · Parent(s): a4f6452

feat: validate build with original test suite
Files changed:
- README.md +36 -0
- build.toml +24 -2
- tests/conftest.py +110 -0
- tests/fixtures/autouse.py +107 -0
- tests/fixtures/fixtures.py +13 -0
- tests/layers/architectures.py +53 -0
- tests/layers/moe_test.py +199 -0
- tests/ops/binned_gather_test.py +71 -0
- tests/ops/binned_scatter_test.py +87 -0
- tests/ops/cumsum_test.py +44 -0
- tests/ops/histogram_test.py +82 -0
- tests/ops/padded_gather_test.py +94 -0
- tests/ops/padded_scatter_test.py +155 -0
- tests/ops/replicate_test.py +108 -0
- tests/ops/sort_test.py +65 -0
- tests/ops/topology_test.py +81 -0
- tests/test_mb_moe.py +42 -0
- torch-ext/megablocks/__init__.py +6 -2
- torch-ext/megablocks/ops/cumsum.py +1 -1
- torch-ext/megablocks/ops/histogram.py +1 -1
- torch-ext/megablocks/ops/replicate.py +1 -2
- torch-ext/megablocks/ops/sort.py +1 -2
- torch-ext/megablocks/ops/topology.py +1 -2
- torch-ext/torch_binding.cpp +13 -13
README.md CHANGED

@@ -4,3 +4,39 @@ tags:
 - kernel
 ---
 
+
+
+```bash
+nix develop --show-trace -i -L .#test --command python -m pytest -s tests
+```
+
+expected output:
+
+```
+============== test session starts ===============
+platform linux -- Python 3.12.10, pytest-8.3.5, pluggy-1.5.0
+rootdir: /home/ubuntu/Projects/megablocks-moe
+plugins: hypothesis-6.130.12
+collecting 43 items world_size=1
+collected 387 items
+
+tests/layers/moe_test.py ...........................................
+tests/ops/binned_gather_test.py .....................
+tests/ops/binned_scatter_test.py .....................
+tests/ops/cumsum_test.py ................................
+tests/ops/histogram_test.py ......................................................
+tests/ops/padded_gather_test.py ......................................
+tests/ops/padded_scatter_test.py ......................................................
+tests/ops/replicate_test.py ..................................................................................
+tests/ops/sort_test.py ..................
+tests/ops/topology_test.py ....................
+tests/test_mb_moe.py megablocks_moe module imported successfully.
+Available functions: ['Arguments', 'MLP', 'MoE', 'ParallelDroplessMLP', 'ParallelMLP', 'SparseGLU', 'SparseMLP', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_megablocks_a4f6452_dirty', '_ops', 'argsort', 'backend', 'cumsum', 'dMoE', 'exclusive_cumsum', 'get_load_balancing_loss', 'grouped_gemm_util', 'histogram', 'inclusive_cumsum', 'indices', 'layers', 'ops', 'replicate_backward', 'replicate_forward', 'sort', 'torch']
+.cumsum output: tensor([0, 1, 3, 6], device='cuda:0', dtype=torch.int16)
+...
+
+================ warnings summary ================
+...
+-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
+======= 387 passed, 18 warnings in 54.63s ========
+```
build.toml CHANGED

@@ -10,6 +10,20 @@ src = [
 
 [kernel.megablocks]
 backend = "cuda"
+cuda-capabilities = [
+  "7.0",
+  "7.2",
+  "7.5",
+  "8.0",
+  "8.6",
+  "8.7",
+  "8.9",
+  "9.0",
+  "10.0",
+  "10.1",
+  "12.0",
+]
+depends = ["torch", "cutlass_3_8"]
 src = [
   "csrc/new_cumsum.h",
   "csrc/new_cumsum.cu",
@@ -22,9 +36,17 @@ src = [
   "csrc/new_sort.h",
   "csrc/new_sort.cu",
 ]
-depends = [ "torch", "cutlass_3_8" ]
 
 [test]
 python-git-packages = [
-  { url = "https://github.com/stanford-futuredata/stk.git", rev = "7363137", sha256 = "0m6g5l9nlwaiwybg5j8dhnz159wdpabdnkzapnn3dsifxrsb59vz" }
+  { url = "https://github.com/stanford-futuredata/stk.git", rev = "7363137", sha256 = "0m6g5l9nlwaiwybg5j8dhnz159wdpabdnkzapnn3dsifxrsb59vz" },
+  { url = "https://github.com/mosaicml/composer.git", rev = "v0.9.0", sha256 = "ekJ5nE6JwYY6Ld9kIk72R/a3iI943Gd5yvAkBHQs5aI=" },
+  # { url = "https://github.com/tgale96/grouped_gemm.git", rev = "v0.3.0", sha256 = "sha256-fS6MuDj6yQ00CSzFrmAmM20/ccvtLJ1MFjfeqdwuPl8=" }
+]
+python-packages = [
+  "tqdm",
+  "py-cpuinfo",
+  "importlib-metadata",
+  "torchmetrics",
+  # "yahp"
 ]
tests/conftest.py ADDED (new file, 110 lines)

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import os
from typing import List, Optional

import pytest
# from composer.utils import reproducibility

# Allowed options for pytest.mark.world_size()
WORLD_SIZE_OPTIONS = (1, 2)

# Enforce deterministic mode before any tests start.
# reproducibility.configure_deterministic_mode()

# TODO: allow plugins when deps resolved

# Add the path of any pytest fixture files you want to make global
pytest_plugins = [
    # 'tests.fixtures.autouse',
    'tests.fixtures.fixtures',
]


def _get_world_size(item: pytest.Item):
    """Returns the world_size of a test, defaults to 1."""
    _default = pytest.mark.world_size(1).mark
    return item.get_closest_marker('world_size', default=_default).args[0]


def _get_option(
    config: pytest.Config,
    name: str,
    default: Optional[str] = None,
) -> str:  # type: ignore
    val = config.getoption(name)
    if val is not None:
        assert isinstance(val, str)
        return val
    val = config.getini(name)
    if val == []:
        val = None
    if val is None:
        if default is None:
            pytest.fail(f'Config option {name} is not specified but is required',)
        val = default
    assert isinstance(val, str)
    return val


def _add_option(
    parser: pytest.Parser,
    name: str,
    help: str,
    choices: Optional[list[str]] = None,
):
    parser.addoption(
        f'--{name}',
        default=None,
        type=str,
        choices=choices,
        help=help,
    )
    parser.addini(
        name=name,
        help=help,
        type='string',
        default=None,
    )


def pytest_collection_modifyitems(
    config: pytest.Config,
    items: List[pytest.Item],
) -> None:
    """Filter tests by world_size (for multi-GPU tests)"""
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    print(f'world_size={world_size}')

    conditions = [
        lambda item: _get_world_size(item) == world_size,
    ]

    # keep items that satisfy all conditions
    remaining = []
    deselected = []
    for item in items:
        if all(condition(item) for condition in conditions):
            remaining.append(item)
        else:
            deselected.append(item)

    if deselected:
        config.hook.pytest_deselected(items=deselected)
        items[:] = remaining


def pytest_addoption(parser: pytest.Parser) -> None:
    _add_option(
        parser,
        'seed',
        help="""\
Rank zero seed to use. `reproducibility.seed_all(seed + dist.get_global_rank())` will be invoked
before each test.""",
    )


def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
    if exitstatus == 5:
        session.exitstatus = 0  # Ignore no-test-ran errors
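A note on the `world_size` plumbing above: a test opts into multi-GPU runs with a `pytest.mark.world_size(n)` marker; unmarked tests default to 1 and get deselected when `WORLD_SIZE` is set to anything else (hence the `collecting 43 items world_size=1` line in the README output). A minimal sketch of a test using the marker — the test body here is hypothetical, only the marker protocol comes from the conftest:

```python
import pytest
import torch


@pytest.mark.gpu
@pytest.mark.world_size(2)  # collected only when the run sets WORLD_SIZE=2
def test_needs_two_ranks():
    # Hypothetical body; conftest.py filters purely on the marker's argument.
    assert torch.cuda.device_count() >= 2
```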
tests/fixtures/autouse.py ADDED (new file, 107 lines)

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import gc
import logging
import os

import composer
import pytest
import torch
from composer.devices import DeviceCPU, DeviceGPU
from composer.utils import dist, reproducibility


@pytest.fixture(autouse=True)
def clear_cuda_cache(request: pytest.FixtureRequest):
    """Clear memory between GPU tests."""
    marker = request.node.get_closest_marker('gpu')
    if marker is not None and torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()  # Only gc on GPU tests as it 2x slows down CPU tests


@pytest.fixture(autouse=True)
def reset_mlflow_tracking_dir():
    """Reset MLFlow tracking dir so it doesn't persist across tests."""
    try:
        import mlflow
        mlflow.set_tracking_uri(None)  # type: ignore
    except ModuleNotFoundError:
        # MLFlow not installed
        pass


@pytest.fixture(scope='session')
def cleanup_dist():
    """Ensure all dist tests clean up resources properly."""
    yield
    # Avoid race condition where a test is still writing to a file on one rank
    # while the file system is being torn down on another rank.
    dist.barrier()


@pytest.fixture(autouse=True, scope='session')
def configure_dist(request: pytest.FixtureRequest):
    # Configure dist globally when the world size is greater than 1,
    # so individual tests that do not use the trainer
    # do not need to worry about manually configuring dist.

    if dist.get_world_size() == 1:
        return

    device = None

    for item in request.session.items:
        device = DeviceCPU() if item.get_closest_marker('gpu') is None else DeviceGPU()
        break

    assert device is not None

    if not dist.is_initialized():
        dist.initialize_dist(device, timeout=300.0)
    # Hold PyTest until all ranks have reached this barrier. Ensure that no rank starts
    # any test before other ranks are ready to start it, which could be a cause of random timeouts
    # (e.g. rank 1 starts the next test while rank 0 is finishing up the previous test).
    dist.barrier()


@pytest.fixture(autouse=True)
def set_log_levels():
    """Ensures all log levels are set to DEBUG."""
    logging.basicConfig()
    logging.getLogger(composer.__name__).setLevel(logging.DEBUG)


@pytest.fixture(autouse=True)
def seed_all(rank_zero_seed: int, monkeypatch: pytest.MonkeyPatch):
    """Monkeypatch reproducibility.

    Make get_random_seed always return the rank zero seed, and set the random seed before each test to the rank local
    seed.
    """
    monkeypatch.setattr(
        reproducibility,
        'get_random_seed',
        lambda: rank_zero_seed,
    )
    reproducibility.seed_all(rank_zero_seed + dist.get_global_rank())


@pytest.fixture(autouse=True)
def remove_run_name_env_var():
    # Remove environment variables for run names in unit tests
    composer_run_name = os.environ.get('COMPOSER_RUN_NAME')
    run_name = os.environ.get('RUN_NAME')

    if 'COMPOSER_RUN_NAME' in os.environ:
        del os.environ['COMPOSER_RUN_NAME']
    if 'RUN_NAME' in os.environ:
        del os.environ['RUN_NAME']

    yield

    if composer_run_name is not None:
        os.environ['COMPOSER_RUN_NAME'] = composer_run_name
    if run_name is not None:
        os.environ['RUN_NAME'] = run_name
tests/fixtures/fixtures.py ADDED (new file, 13 lines)

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import pytest

from tests.conftest import _get_option


@pytest.fixture
def rank_zero_seed(pytestconfig: pytest.Config) -> int:
    """Read the rank_zero_seed from the CLI option."""
    seed = _get_option(pytestconfig, 'seed', default='0')
    return int(seed)
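Since `_add_option` registered `seed` both as a CLI flag and an ini key, `rank_zero_seed` can be driven by `pytest --seed 123` or a `seed =` ini entry, falling back to `'0'`. A minimal sketch of a test consuming the fixture:

```python
def test_reads_seed(rank_zero_seed: int):
    # Hypothetical test; the fixture resolves --seed / the 'seed' ini key,
    # defaulting to '0', and returns it as an int.
    assert rank_zero_seed >= 0
```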
tests/layers/architectures.py ADDED (new file, 53 lines)

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn.functional as F

from megablocks.layers.arguments import Arguments


class FFN(torch.nn.Module):

    def __init__(self, args: Arguments):
        super().__init__()
        self.w1 = torch.nn.Parameter(
            torch.empty(
                args.hidden_size,
                args.ffn_hidden_size,
                device=args.device,
                dtype=torch.float16 if args.fp16 else torch.float32,
            ),
        )
        self.w2 = torch.nn.Parameter(
            torch.empty(
                args.ffn_hidden_size,
                args.hidden_size,
                device=args.device,
                dtype=torch.float16 if args.fp16 else torch.float32,
            ),
        )

    def forward(self, x):
        return torch.matmul(
            F.gelu(torch.matmul(x, self.w1), approximate='tanh'),
            self.w2,
        )


class GLU(FFN):

    def __init__(self, args: Arguments):
        super().__init__(args)
        self.v1 = torch.nn.Parameter(
            torch.empty(
                args.hidden_size,
                args.ffn_hidden_size,
                device=args.device,
                dtype=torch.float16 if args.fp16 else torch.float32,
            ),
        )

    def forward(self, x):
        x1 = F.gelu(torch.matmul(x, self.w1), approximate='tanh') * torch.matmul(x, self.v1)
        return torch.matmul(x1, self.w2)
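Note that `FFN`/`GLU` allocate with `torch.empty`, so callers are expected to initialize or copy the weights (as `construct_moe` does in the next file). A hedged usage sketch, assuming `Arguments` accepts the fields referenced here (`hidden_size`, `ffn_hidden_size`) as constructor kwargs, as `moe_test.py` below does:

```python
import torch
from megablocks.layers.arguments import Arguments
from tests.layers.architectures import FFN

args = Arguments(hidden_size=512, ffn_hidden_size=1024)
ffn = FFN(args).cuda().half()
with torch.no_grad():               # weights come from torch.empty: initialize them
    ffn.w1.normal_(0.0, 0.1)
    ffn.w2.normal_(0.0, 0.1)
x = torch.randn(16, 8, 512).cuda().half()
assert ffn(x).shape == x.shape      # matmul -> tanh-approximated GELU -> matmul
```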
tests/layers/moe_test.py ADDED (new file, 199 lines)

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from functools import partial

import pytest
import torch

from megablocks.layers.arguments import Arguments
from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss
from megablocks.layers.router import batched_router_zloss, clear_router_zloss
from tests.layers.architectures import FFN

_FORWARD_TESTS = (
    (16, 1024, 512, 1, 1),
    (16, 1024, 512, 2, 1),
    (16, 1024, 512, 4, 1),
    (16, 1024, 512, 8, 1),
    (8, 2048, 512, 1, 1),
    (8, 2048, 512, 2, 1),
    (8, 2048, 512, 4, 1),
    (16, 1024, 512, 2, 2),
    (16, 1024, 512, 4, 2),
    (16, 1024, 512, 4, 4),
    (16, 1024, 512, 8, 2),
    (16, 1024, 512, 8, 4),
    (16, 1024, 512, 8, 8),
)

_DENSE_TESTS = (
    (16, 1024, 512),
    (8, 2048, 512),
)


def construct_moe(
    hidden_size: int,
    ffn_hidden_size: int,
    moe_num_experts: int = 1,
    moe_capacity_factor: int = 1,
    moe_top_k: int = 1,
    moe_zloss_weight: float = 0,
):
    # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
    # TODO: Remove this once sparse is supported with triton >=3.2.0
    try:
        import triton
        if triton.__version__ >= '3.2.0':
            pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
    except ImportError:
        pass

    init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
    args = Arguments(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        moe_num_experts=moe_num_experts,
        moe_capacity_factor=moe_capacity_factor,
        moe_top_k=moe_top_k,
        init_method=init_method,
        moe_zloss_weight=moe_zloss_weight,
    )

    mlp = FFN(args)
    moe_mlp = MoE(args)

    mlp.cuda(torch.cuda.current_device()).half()
    moe_mlp.cuda(torch.cuda.current_device()).half()

    # Set the baseline parameters to match exactly.
    if moe_num_experts == 1:
        with torch.no_grad():
            mlp.w1.copy_(moe_mlp.experts.mlp.w1.squeeze())
            mlp.w2.copy_(moe_mlp.experts.mlp.w2.squeeze())
    return args, mlp, moe_mlp


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
def test_moe_forward(bs: int, sl: int, hs: int, num_experts: int, top_k: int):
    x = torch.randn(sl, bs, hs).half().cuda()

    _, _, layer = construct_moe(
        hidden_size=hs,
        ffn_hidden_size=hs * 2,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
    )

    out, _ = layer(x)
    assert out.shape == x.shape
    clear_load_balancing_loss()


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
def test_moe_forward_backward(
    bs: int,
    sl: int,
    hs: int,
    num_experts: int,
    top_k: int,
):
    x = torch.randn(sl, bs, hs).half().cuda()
    x.requires_grad_(True)

    args, _, layer = construct_moe(
        hidden_size=hs,
        ffn_hidden_size=hs * 2,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
    )

    out, _ = layer(x)
    assert out.shape == x.shape

    loss = out.sum() + batched_load_balancing_loss(args)
    loss.backward()
    layer.zero_grad(set_to_none=True)
    x.grad = None
    clear_load_balancing_loss()


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
def test_moe_forward_backward_with_zloss(
    bs: int,
    sl: int,
    hs: int,
    num_experts: int,
    top_k: int,
):
    x = torch.randn(sl, bs, hs).half().cuda()
    x.requires_grad_(True)

    args, _, layer = construct_moe(
        hidden_size=hs,
        ffn_hidden_size=hs * 2,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
        moe_zloss_weight=1e-3,
    )

    out, _ = layer(x)
    assert out.shape == x.shape

    loss = out.sum() + batched_load_balancing_loss(args)
    loss.backward()
    layer.zero_grad(set_to_none=True)
    x.grad = None
    clear_load_balancing_loss()
    clear_router_zloss()


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS)
def test_moe_forward_vs_dense(bs: int, sl: int, hs: int):
    x = torch.randn(sl, bs, hs).half().cuda()

    _, mlp, moe_mlp = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2)

    expected_out = mlp(x)
    out, _ = moe_mlp(x)
    assert out.shape == x.shape == expected_out.shape
    assert torch.allclose(out, expected_out)
    clear_load_balancing_loss()


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS)
def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int):
    x = torch.randn(sl, bs, hs).half().cuda()
    x.requires_grad_(True)

    _, mlp, moe_mlp = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2)

    out, _ = moe_mlp(x)
    loss = out.sum()
    loss.backward()
    w1_grad = moe_mlp.experts.mlp.w1.grad.detach().squeeze()
    w2_grad = moe_mlp.experts.mlp.w2.grad.detach().squeeze()
    moe_mlp.zero_grad(set_to_none=True)
    x.grad = None
    clear_load_balancing_loss()

    expected_out = mlp(x)
    expected_loss = expected_out.sum()
    expected_loss.backward()
    expected_w1_grad = mlp.w1.grad.detach()
    expected_w2_grad = mlp.w2.grad.detach()
    mlp.zero_grad(set_to_none=True)
    x.grad = None

    # Verify the gradients match.
    assert w1_grad.shape == expected_w1_grad.shape
    assert w2_grad.shape == expected_w2_grad.shape
    assert torch.allclose(w1_grad, expected_w1_grad)
    assert torch.allclose(w2_grad, expected_w2_grad)
    clear_load_balancing_loss()
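Why the `*_vs_dense` comparisons can demand `allclose` agreement: `construct_moe` copies the expert weights into the dense `FFN` only when `moe_num_experts == 1`, and with a single expert the router's softmax weight is identically 1, so the MoE path reduces to the same two matmuls. A small sketch of that last step (plain PyTorch, not megablocks internals):

```python
import torch

# Softmax over a single expert logit is exactly 1.0, so the expert output
# is returned unscaled -- the MoE forward then matches the dense FFN.
logits = torch.randn(5, 1)
probs = torch.softmax(logits, dim=-1)
assert torch.all(probs == 1.0)
```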
tests/ops/binned_gather_test.py ADDED (new file, 71 lines)

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from megablocks import ops

BINNED_GATHER_TESTS = (
    (4, 2, 2, 1),
    (4, 2, 2, 2),
    (4, 2, 2, 4),
    (1024, 1536, 4, 1),
    (1024, 1536, 4, 2),
    (1024, 1536, 4, 4),
    (1024, 1536, 64, 1),
    (1024, 1536, 64, 2),
    (1024, 1536, 64, 4),
    (1024, 1536, 128, 1),
    (1024, 1536, 128, 2),
    (1024, 1536, 128, 4),
    (16384, 768, 4, 1),
    (16384, 768, 4, 2),
    (16384, 768, 4, 4),
    (16384, 768, 64, 1),
    (16384, 768, 64, 2),
    (16384, 768, 64, 4),
    (16384, 768, 128, 1),
    (16384, 768, 128, 2),
    (16384, 768, 128, 4),
)


@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), BINNED_GATHER_TESTS)
def test_binned_gather(sl: int, hs: int, ne: int, top_k: int):
    # NOTE: Capacity factor == 1.
    ec = (sl * top_k) // ne

    # Create the data and indices.
    x = torch.randn((sl, hs)).cuda().half()

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
    _, indices = ops.sort(top_expert)
    bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)

    def binned_gather(
        x: torch.Tensor,
        indices: torch.Tensor,
        bins: torch.Tensor,
        ec: int,
        top_k: int,
    ):
        x = x.cpu().numpy()
        indices = indices.cpu().numpy()
        bins = bins.cpu().numpy()
        start = 0
        out = np.zeros((ne, ec, hs))
        for i in range(ne):
            end = bins[i]
            for j in range(min(ec, end - start)):
                index = indices[start + j] // top_k
                out[i, j, :] = x[index, :]
            start = end
        return torch.from_numpy(out).cuda().half()

    out = ops.binned_gather(x, indices, bins, ec, top_k)
    expected_out = binned_gather(x, indices, bins, ec, top_k)
    assert torch.all(torch.eq(out, expected_out))
tests/ops/binned_scatter_test.py ADDED (new file, 87 lines)

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from megablocks import ops

_BINNED_SCATTER_TESTS = (
    (4, 2, 2, 1),
    (4, 2, 2, 2),
    (4, 2, 2, 4),
    (1024, 1536, 4, 1),
    (1024, 1536, 4, 2),
    (1024, 1536, 4, 4),
    (1024, 1536, 64, 1),
    (1024, 1536, 64, 2),
    (1024, 1536, 64, 4),
    (1024, 1536, 128, 1),
    (1024, 1536, 128, 2),
    (1024, 1536, 128, 4),
    (16384, 768, 4, 1),
    (16384, 768, 4, 2),
    (16384, 768, 4, 4),
    (16384, 768, 64, 1),
    (16384, 768, 64, 2),
    (16384, 768, 64, 4),
    (16384, 768, 128, 1),
    (16384, 768, 128, 2),
    (16384, 768, 128, 4),
)


@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), _BINNED_SCATTER_TESTS)
def testBinnedScatter(sl: int, hs: int, ne: int, top_k: int):
    # NOTE: Capacity factor == 1.
    ec = (sl * top_k) // ne

    # Create the data and indices.
    x = torch.randn((sl, hs)).cuda().half()

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
    _, indices = ops.sort(top_expert)
    bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)

    # Sample weights for the scatter reduce.
    weights = torch.rand((sl * top_k,)).cuda().half()

    x = ops.binned_gather(x, indices, bins, ec, top_k)

    def binned_scatter(
        x: torch.Tensor,
        indices: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        top_k: int,
    ):
        x = x.cpu().numpy()
        indices = indices.cpu().numpy()
        weights = weights.cpu().numpy()
        bins = bins.cpu().numpy()
        start = 0
        out = np.zeros((sl, hs))
        for i in range(ne):
            end = bins[i]
            for j in range(min(ec, end - start)):
                index = indices[start + j]
                scale = weights[index]
                index //= top_k

                out[index, :] += scale * x[i, j, :]
            start = end
        return torch.from_numpy(out).cuda().half()

    out = ops.binned_scatter(x, indices, weights, bins, top_k)
    expected_out = binned_scatter(x, indices, weights, bins, top_k)

    # NOTE: We need to check approximate equality because the
    # scatter reduce uses atomics.
    assert np.testing.assert_allclose(
        out.cpu(),
        expected_out.cpu(),
        rtol=5e-3,
    ) is None
tests/ops/cumsum_test.py ADDED (new file, 44 lines)

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from megablocks import ops

CUMSUM_TESTS = (
    (1, 32),
    (2, 32),
    (2, 1024),
    (4, 1024),
    (8, 1024),
    (16, 1024),
    (32, 1024),
    (64, 1024),
    (128, 1024),
    (2, 16384),
    (4, 16384),
    (8, 16384),
    (16, 16384),
    (32, 16384),
    (64, 16384),
    (128, 16384),
)


@pytest.mark.gpu
@pytest.mark.parametrize(('n', 'm'), CUMSUM_TESTS)
def test_exclusive_cumsum(n: int, m: int):
    x = torch.randint(0, 2, (n, m)).long().cuda()
    out = ops.exclusive_cumsum(x, 1) * x
    expected_out = (torch.cumsum(x, dim=1) - 1) * x
    assert torch.all(torch.eq(out, expected_out))


@pytest.mark.gpu
@pytest.mark.parametrize(('n', 'm'), CUMSUM_TESTS)
def test_inclusive_cumsum(n: int, m: int):
    x = torch.randint(0, 2, (n, m)).long().cuda()
    out = ops.inclusive_cumsum(x, 1)
    expected_out = torch.cumsum(x, dim=1)
    assert torch.all(torch.eq(out, expected_out))
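The exclusive-cumsum check above leans on a masking identity rather than building a true exclusive scan in PyTorch: for a 0/1 tensor, wherever `x == 1` the exclusive sum equals the inclusive sum minus 1, and multiplying both sides by `x` zeroes out the positions where they may disagree. A small sketch of the identity:

```python
import torch

x = torch.tensor([[1, 0, 1, 1, 0, 1]])
inclusive = torch.cumsum(x, dim=1)   # [1, 1, 2, 3, 3, 4]
exclusive = inclusive - x            # exclusive scan of a 0/1 tensor
# Masked by x, 'inclusive - 1' agrees with the exclusive scan exactly.
assert torch.equal(exclusive * x, (inclusive - 1) * x)
```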
tests/ops/histogram_test.py ADDED (new file, 82 lines)

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from megablocks import ops

_HISTOGRAM_TESTS = (
    (1, 32, torch.int16, 128),
    (1, 1024, torch.int16, 128),
    (1, 16384, torch.int16, 128),
    (1, 32, torch.int32, 128),
    (1, 1024, torch.int32, 128),
    (1, 16384, torch.int32, 128),
    (1, 32, torch.int64, 128),
    (1, 1024, torch.int64, 128),
    (1, 16384, torch.int64, 128),
    (1, 32, torch.int16, 1024),
    (1, 1024, torch.int16, 1024),
    (1, 16384, torch.int16, 1024),
    (1, 32, torch.int32, 1024),
    (1, 1024, torch.int32, 1024),
    (1, 16384, torch.int32, 1024),
    (1, 32, torch.int64, 1024),
    (1, 1024, torch.int64, 1024),
    (1, 16384, torch.int64, 1024),
    (2, 32, torch.int16, 128),
    (2, 1024, torch.int16, 128),
    (2, 16384, torch.int16, 128),
    (2, 32, torch.int32, 128),
    (2, 1024, torch.int32, 128),
    (2, 16384, torch.int32, 128),
    (2, 32, torch.int64, 128),
    (2, 1024, torch.int64, 128),
    (2, 16384, torch.int64, 128),
    (2, 32, torch.int16, 1024),
    (2, 1024, torch.int16, 1024),
    (2, 16384, torch.int16, 1024),
    (2, 32, torch.int32, 1024),
    (2, 1024, torch.int32, 1024),
    (2, 16384, torch.int32, 1024),
    (2, 32, torch.int64, 1024),
    (2, 1024, torch.int64, 1024),
    (2, 16384, torch.int64, 1024),
    (8, 32, torch.int16, 128),
    (8, 1024, torch.int16, 128),
    (8, 16384, torch.int16, 128),
    (8, 32, torch.int32, 128),
    (8, 1024, torch.int32, 128),
    (8, 16384, torch.int32, 128),
    (8, 32, torch.int64, 128),
    (8, 1024, torch.int64, 128),
    (8, 16384, torch.int64, 128),
    (8, 32, torch.int16, 1024),
    (8, 1024, torch.int16, 1024),
    (8, 16384, torch.int16, 1024),
    (8, 32, torch.int32, 1024),
    (8, 1024, torch.int32, 1024),
    (8, 16384, torch.int32, 1024),
    (8, 32, torch.int64, 1024),
    (8, 1024, torch.int64, 1024),
    (8, 16384, torch.int64, 1024),
)


# Override the seed_all fixture in autouse.py because
# _histc_cuda does not have a deterministic implementation
@pytest.fixture()
def seed_all():
    torch.use_deterministic_algorithms(False)
    return


@pytest.mark.gpu
@pytest.mark.parametrize(('m', 'n', 'dtype', 'max_val'), _HISTOGRAM_TESTS)
def test_histogram(m: int, n: int, dtype: torch.dtype, max_val: int):
    x = torch.randint(0, max_val, (m, n)).cuda().to(dtype)

    out = ops.histogram(x, max_val)
    expected_out = torch.stack([torch.histc(y, max_val, 0, max_val - 1) for y in torch.split(x, 1)])
    assert torch.all(torch.eq(out, expected_out))
tests/ops/padded_gather_test.py ADDED (new file, 94 lines)

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from megablocks import ops

PADDED_GATHER_TESTS = (
    (4, 2, 2, 1),
    (4, 2, 2, 2),
    (1024, 1, 4, 1),
    (1024, 1, 4, 2),
    (1024, 1, 4, 4),
    (1024, 1, 64, 1),
    (1024, 1, 64, 2),
    (1024, 1, 64, 4),
    (1024, 1, 128, 1),
    (1024, 1, 128, 2),
    (1024, 1, 128, 4),
    (1024, 1536, 4, 1),
    (1024, 1536, 4, 2),
    (1024, 1536, 4, 4),
    (1024, 1536, 64, 1),
    (1024, 1536, 64, 2),
    (1024, 1536, 64, 4),
    (1024, 1536, 128, 1),
    (1024, 1536, 128, 2),
    (1024, 1536, 128, 4),
    (16384, 768, 4, 1),
    (16384, 768, 4, 2),
    (16384, 768, 4, 4),
    (16384, 768, 64, 1),
    (16384, 768, 64, 2),
    (16384, 768, 64, 4),
    (16384, 768, 128, 1),
    (16384, 768, 128, 2),
    (16384, 768, 128, 4),
    (16384, 1, 4, 1),
    (16384, 1, 4, 2),
    (16384, 1, 4, 4),
    (16384, 1, 64, 1),
    (16384, 1, 64, 2),
    (16384, 1, 64, 4),
    (16384, 1, 128, 1),
    (16384, 1, 128, 2),
    (16384, 1, 128, 4),
)


@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), PADDED_GATHER_TESTS)
def testPaddedGather(sl: int, hs: int, ne: int, top_k: int):
    # Create the data and indices.
    x = torch.randn((sl, hs)).cuda().half()

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
    bin_ids, indices = ops.sort(top_expert)
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
    bins = ops.inclusive_cumsum(tokens_per_expert, 0)

    def padded_gather(
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        bins: torch.Tensor,
        padded_bins: torch.Tensor,
        top_k: int,
    ):
        x = x.cpu().numpy()
        indices = indices.cpu().numpy()
        bin_ids = bin_ids.cpu().numpy()
        bins = bins.cpu().numpy()
        padded_bins = padded_bins.cpu().numpy()

        out = np.zeros((padded_bins[-1], hs))
        in_idx = 0
        for i, end in enumerate(bins):
            out_idx = 0 if i == 0 else padded_bins[i - 1]
            end = bins[i]
            while in_idx < end:
                load_idx = indices[in_idx] // top_k
                out[out_idx, :] = x[load_idx, :]
                in_idx += 1
                out_idx += 1
        return torch.from_numpy(out).cuda().half()

    out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
    expected_out = padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
    assert torch.all(torch.eq(out, expected_out))
tests/ops/padded_scatter_test.py ADDED (new file, 155 lines)

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from megablocks import ops

PADDED_SCATTER_TESTS = [
    (4, 2, 2, 2),
    (4, 2, 2, 1),
    (4, 2, 2, 1),
    (4, 2, 2, 1),
    (4, 2, 2, 2),
    (4, 2, 2, 2),
    (1024, 1, 4, 1),
    (1024, 1, 4, 2),
    (1024, 1, 4, 4),
    (1024, 1, 4, 1),
    (1024, 1, 4, 2),
    (1024, 1, 4, 4),
    (1024, 1, 4, 1),
    (1024, 1, 4, 2),
    (1024, 1, 4, 4),
    (1024, 1, 64, 1),
    (1024, 1, 64, 2),
    (1024, 1, 64, 4),
    (1024, 1, 128, 1),
    (1024, 1, 128, 2),
    (1024, 1, 128, 4),
    (1024, 1536, 4, 1),
    (1024, 1536, 4, 2),
    (1024, 1536, 4, 4),
    (1024, 1536, 4, 4),
    (1024, 1536, 4, 4),
    (1024, 1536, 64, 1),
    (1024, 1536, 64, 2),
    (1024, 1536, 64, 4),
    (1024, 1536, 128, 1),
    (1024, 1536, 128, 2),
    (1024, 1536, 128, 4),
    (1024, 1536, 128, 1),
    (1024, 1536, 128, 1),
    (16384, 768, 4, 1),
    (16384, 768, 4, 2),
    (16384, 768, 4, 4),
    (16384, 768, 64, 1),
    (16384, 768, 64, 2),
    (16384, 768, 64, 4),
    (16384, 768, 128, 1),
    (16384, 768, 128, 2),
    (16384, 768, 128, 4),
    (16384, 1, 4, 1),
    (16384, 1, 4, 2),
    (16384, 1, 4, 4),
    (16384, 1, 64, 1),
    (16384, 1, 64, 2),
    (16384, 1, 64, 4),
    (16384, 1, 128, 1),
    (16384, 1, 128, 2),
    (16384, 1, 128, 4),
    (16384, 1, 128, 2),
    (16384, 1, 128, 2),
]


def _to_numpy(x: torch.Tensor) -> np.ndarray:
    return x.detach().cpu().numpy()


@pytest.mark.gpu
@pytest.mark.parametrize((
    'sl',
    'hs',
    'ne',
    'top_k',
), PADDED_SCATTER_TESTS)
def testPaddedScatter(sl: int, hs: int, ne: int, top_k: int):
    # Create the data and indices.
    x = torch.randn((sl, hs), requires_grad=True).cuda().half()

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
    bin_ids, indices = ops.sort(top_expert)
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
    bins = ops.inclusive_cumsum(tokens_per_expert, 0)

    # Sample weights for the scatter reduce.
    weights = torch.rand((sl * top_k,), requires_grad=True).cuda().half()

    # Gather the data to prepare for backwards.
    x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)

    def padded_scatter(
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        padded_bins: torch.Tensor,
        top_k: int,
    ):
        x = x.detach().cpu().numpy()
        indices: np.ndarray = _to_numpy(indices)
        bin_ids: np.ndarray = _to_numpy(bin_ids)
        weights: np.ndarray = _to_numpy(weights)
        bins: np.ndarray = _to_numpy(bins)
        padded_bins: np.ndarray = _to_numpy(padded_bins)

        out = np.zeros((indices.shape[0] // top_k, hs))
        out_idx = 0
        for i in range(len(bins)):
            in_idx = 0 if i == 0 else padded_bins[i - 1]
            end = bins[i]
            while out_idx < end:
                store_idx = indices[out_idx]
                scale = weights[store_idx]
                store_idx //= top_k

                out[store_idx, :] += scale * x[in_idx, :]
                out_idx += 1
                in_idx += 1
        return torch.from_numpy(out).cuda().half()

    out = ops.padded_scatter(
        x,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        top_k,
    )
    expected_out = padded_scatter(
        x,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        top_k,
    )

    out.backward(torch.randn_like(out))  # sanity check backward pass

    # NOTE: We need to check approximate equality because the scatter reduce uses atomics.
    # np.testing.assert_allclose returns `None` if no error and raises an AssertionError if an error exists
    assert np.testing.assert_allclose(
        _to_numpy(out),
        _to_numpy(expected_out),
        rtol=5e-3,
    ) is None
tests/ops/replicate_test.py ADDED (new file, 108 lines)

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

try:
    from megablocks._ops import ops as backend  # type: ignore
except ModuleNotFoundError as e:
    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e

from megablocks import ops


def promote_scalar(x: torch.Tensor) -> torch.Tensor:
    return x.view(1) if not len(x.size()) else x


REPLICATE_TESTS = [
    (8, 1, 1),
    (8, 2, 1),
    (8, 4, 1),
    (8, 8, 1),
    (8, 2, 2),
    (8, 4, 2),
    (8, 8, 2),
    (8, 2, 4),
    (8, 4, 4),
    (8, 8, 4),
    (8, 2, 8),
    (8, 4, 8),
    (8, 8, 8),
    (16384, 2, 1),
    (16384, 4, 1),
    (16384, 8, 1),
    (16384, 16, 1),
    (16384, 32, 1),
    (16384, 64, 1),
    (16384, 128, 1),
    (16384, 2, 2),
    (16384, 4, 2),
    (16384, 8, 2),
    (16384, 16, 2),
    (16384, 32, 2),
    (16384, 64, 2),
    (16384, 128, 2),
    (16384, 2, 4),
    (16384, 4, 4),
    (16384, 8, 4),
    (16384, 16, 4),
    (16384, 32, 4),
    (16384, 64, 4),
    (16384, 128, 4),
    (16384, 2, 8),
    (16384, 4, 8),
    (16384, 8, 8),
    (16384, 16, 8),
    (16384, 32, 8),
    (16384, 64, 8),
    (16384, 128, 8),
]


@pytest.mark.gpu
@pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
def test_replicate(tokens: int, num_centers: int, top_k: int):
    tokens_to_centers = torch.randint(0, num_centers, (tokens,)).cuda().int()
    tokens_per_center = ops.histogram(tokens_to_centers, num_centers)
    bins = ops.inclusive_cumsum(tokens_per_center, 0)
    bins = promote_scalar(bins)
    center_weights = torch.randn(top_k, num_centers).cuda().half()

    def replicate(x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
        x = x.cpu().numpy()
        bins = bins.cpu().numpy()
        out = np.zeros((x.shape[0], num_outputs))
        for batch_idx in range(x.shape[0]):
            start = 0
            for i, end in enumerate(bins):
                value = x[batch_idx, i]
                while start < end:
                    out[batch_idx, start] = value
                    start += 1
        return torch.from_numpy(out).cuda().half()

    out = ops.replicate(center_weights, bins, tokens)
    expected_out = replicate(center_weights, bins, tokens)
    assert torch.all(torch.eq(out, expected_out))


@pytest.mark.gpu
@pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
def test_replicate_backward(tokens: int, num_centers: int, top_k: int):
    tokens_to_centers = torch.randint(0, num_centers, (tokens,)).cuda().int()
    tokens_per_center = ops.histogram(tokens_to_centers, num_centers)
    bins = ops.inclusive_cumsum(tokens_per_center, 0)
    bins = promote_scalar(bins)
    center_weights = torch.randn(top_k, num_centers).cuda().half()

    grad = ops.replicate(center_weights, bins, tokens)

    out = torch.empty_like(center_weights)
    backend.replicate_backward(grad, bins, out)
    expected_out = center_weights * tokens_per_center.view([1, num_centers])

    # NOTE: This floating-point reduction could be a problem for training stability and accuracy.
    assert torch.allclose(out, expected_out, rtol=1e-2)
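For intuition about the op under test: the forward pass expands one value per bin into one value per token, which for cumulative `bins` built from `tokens_per_center` matches `torch.repeat_interleave` along the token axis; the backward pass sums each bin's gradients, so feeding the forward output back in as `grad` yields `center_weights * tokens_per_center`, exactly what the test asserts. A plain-PyTorch sketch of the forward semantics:

```python
import torch

w = torch.tensor([[1.0, 2.0, 3.0]])          # (top_k=1, num_centers=3)
tokens_per_center = torch.tensor([2, 0, 3])  # bins = inclusive cumsum -> [2, 2, 5]
expanded = torch.repeat_interleave(w, tokens_per_center, dim=1)
assert torch.equal(expanded, torch.tensor([[1.0, 1.0, 3.0, 3.0, 3.0]]))
```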
tests/ops/sort_test.py ADDED (new file, 65 lines)

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, Optional, Union

import numpy as np
import pytest
import torch

from megablocks import ops

SORT_TESTS = [
    (32, torch.int16, None),
    (1024, torch.int16, None),
    (16384, torch.int16, None),
    (32, torch.int32, None),
    (1024, torch.int32, None),
    (16384, torch.int32, None),
    (32, torch.int64, None),
    (1024, torch.int64, None),
    (16384, torch.int64, None),
    (32, torch.int16, 128),
    (1024, torch.int16, 128),
    (16384, torch.int16, 128),
    (32, torch.int32, 128),
    (1024, torch.int32, 128),
    (16384, torch.int32, 128),
    (32, torch.int64, 128),
    (1024, torch.int64, 128),
    (16384, torch.int64, 128),
]


def torch_to_numpy_dtype(dtype: torch.dtype,) -> Union[np.int16, np.int32, np.int64]:
    types: Dict[torch.dtype, Union[np.int16, np.int32, np.int64]] = {
        torch.int16: np.int16,
        torch.int32: np.int32,
        torch.int64: np.int64,
    }
    return types[dtype]


@pytest.mark.gpu
@pytest.mark.parametrize(
    ('n', 'dtype', 'max_val'),
    SORT_TESTS,
)
def test_sort(n: int, dtype: torch.dtype, max_val: Optional[int]):
    if max_val is None:
        max_val = np.iinfo(torch_to_numpy_dtype(dtype)).max
    end_bit = int(np.ceil(np.log2(max_val)))
    x = torch.randint(0, max_val, (n,)).cuda().to(dtype)

    out, indices = ops.sort(x, end_bit)
    expected_out, expected_indices = torch.sort(x)
    assert torch.all(torch.eq(out, expected_out))

    # NOTE: The indices can be in different order depending
    # on sort stability if multiple values in the array are
    # equal.
    data = torch.empty_like(x)
    data.scatter_(0, indices.long(), out)
    expected_data = torch.empty_like(x)
    expected_data.scatter_(0, expected_indices, expected_out)
    assert torch.all(torch.eq(data, expected_data))
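The scatter at the end of `test_sort` is a stability-agnostic check: a radix sort may order equal keys differently than `torch.sort`, so instead of comparing `indices`, both permutations are inverted and must reconstruct the same array. The same trick in isolation:

```python
import torch

# Any valid argsort of a tensor with duplicates (stable or not)
# reconstructs the original when inverted via scatter.
x = torch.tensor([3, 1, 3, 1])
out, idx = torch.sort(x)
recon = torch.empty_like(x)
recon.scatter_(0, idx, out)  # recon[idx[i]] = out[i] = x[idx[i]]
assert torch.equal(recon, x)
```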
tests/ops/topology_test.py ADDED (new file, 81 lines)

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from megablocks import ops

TOPOLOGY_TESTS = (
    (1024, 1536, 2),
    (1024, 1536, 4),
    (1024, 1536, 8),
    (1024, 1536, 16),
    (1024, 1536, 32),
    (1024, 1536, 64),
    (1024, 1536, 128),
    (1024, 1536, 256),
    (1024, 1536, 512),
    (16384, 768, 2),
    (16384, 768, 4),
    (16384, 768, 8),
    (16384, 768, 16),
    (16384, 768, 32),
    (16384, 768, 64),
    (16384, 768, 128),
    (16384, 768, 256),
    (16384, 768, 512),
    (16384, 768, 1024),
    (8, 14336, 8),
)


@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne'), TOPOLOGY_TESTS)
def test_topology(sl: int, hs: int, ne: int):
    # Create the data and indices.
    blocking = 128
    assert hs % blocking == 0

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl,)).cuda().int()
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, blocking)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)

    # Dimensions for the output indices.
    output_block_rows = int(padded_bins[-1]) // blocking
    output_block_columns = hs // blocking

    def topology(
        padded_bins: torch.Tensor,
        blocking: torch.Tensor,
        rows: int,
        columns: int,
    ):
        padded_bins = padded_bins.cpu().numpy()

        out = np.zeros([rows * columns])
        start = 0
        for i in range(padded_bins.shape[0]):
            end = padded_bins[i] // blocking
            while start < end:
                for j in range(columns):
                    out[start * columns + j] = j + i * columns
                start += 1
        return torch.from_numpy(out).cuda().short()

    out = ops.topology(
        padded_bins,
        blocking,
        output_block_rows,
        output_block_columns,
    )
    expected_out = topology(
        padded_bins,
        blocking,
        output_block_rows,
        output_block_columns,
    )
    assert torch.all(torch.eq(out, expected_out))
tests/test_mb_moe.py CHANGED

@@ -1,6 +1,48 @@
+import torch
 import megablocks
 
 def test_import():
     """Simple test to check if the module can be imported."""
     print("megablocks_moe module imported successfully.")
     print("Available functions:", dir(megablocks))
+
+    expected_functions = [
+        "Arguments", "MLP", "MoE", "ParallelDroplessMLP", "ParallelMLP",
+        "SparseGLU", "SparseMLP", "argsort",
+        "backend", "cumsum", "dMoE", "exclusive_cumsum",
+        "get_load_balancing_loss", "grouped_gemm_util", "histogram",
+        "inclusive_cumsum", "indices", "layers", "ops", "replicate_backward",
+        "replicate_forward", "sort", "torch"
+    ]
+
+    # Check if all expected functions are available
+    for func in expected_functions:
+        assert func in dir(megablocks), f"Missing function: {func}"
+
+# exclusive_cumsum
+def test_exclusive_cumsum():
+    """Test exclusive cumulative sum."""
+    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
+    out = torch.empty_like(x)
+    megablocks.exclusive_cumsum(x, 0, out)
+    expected = torch.tensor([0, 1, 3, 6], dtype=torch.float32).cuda()
+    assert torch.equal(out, expected), f"Expected {expected}, got {out}"
+    print("cumsum output:", out)
+
+# inclusive_cumsum
+def test_inclusive_cumsum():
+    """Test inclusive cumulative sum."""
+    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
+    out = torch.empty_like(x)
+    megablocks.inclusive_cumsum(x, dim=0, out=out)
+    expected = torch.tensor([1, 3, 6, 10], dtype=torch.float32).cuda()
+    assert torch.equal(out, expected), f"Expected {expected}, got {out}"
+
+# histogram
+def test_histogram():
+    """Test histogram operation."""
+    x = torch.tensor([0, 1, 1, 2, 2, 2], dtype=torch.int16).cuda()
+    num_bins = 3
+    hist = megablocks.histogram(x, num_bins)
+    expected_hist = torch.tensor([1, 2, 3], dtype=torch.int32).cuda()
+    assert torch.equal(hist, expected_hist), f"Expected {expected_hist}, got {hist}"
torch-ext/megablocks/__init__.py CHANGED

@@ -24,7 +24,9 @@ def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
     Returns:
         The output tensor
     """
-
+    result = ops.exclusive_cumsum(x, dim)
+    out.copy_(result)
+    return out
 
 
 def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
@@ -39,7 +41,9 @@ def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
     Returns:
         The output tensor
     """
-
+    result = ops.inclusive_cumsum(x, dim)
+    out.copy_(result)
+    return out
 
 
 def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
torch-ext/megablocks/ops/cumsum.py CHANGED

@@ -11,7 +11,7 @@ import torch
 # instructions for building the c++ operations.
 try:
     # import megablocks_ops as ops # type: ignore
-
+    from megablocks._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
torch-ext/megablocks/ops/histogram.py CHANGED

@@ -10,7 +10,7 @@ import torch
 # Wrap this in a try-block with better error message and
 # instructions for building the c++ operations.
 try:
-
+    from megablocks._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
torch-ext/megablocks/ops/replicate.py CHANGED

@@ -10,8 +10,7 @@ import torch
 # Wrap this in a try-block with better error message and
 # instructions for building the c++ operations.
 try:
-
-    import megablocks._ops as ops  # type: ignore
+    from megablocks._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
torch-ext/megablocks/ops/sort.py CHANGED

@@ -10,8 +10,7 @@ import torch
 # Wrap this in a try-block with better error message and
 # instructions for building the c++ operations.
 try:
-
-    import megablocks._ops as ops  # type: ignore
+    from megablocks._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
torch-ext/megablocks/ops/topology.py CHANGED

@@ -10,8 +10,7 @@ import torch
 # Wrap this in a try-block with better error message and
 # instructions for building the c++ operations.
 try:
-
-    import megablocks._ops as ops  # type: ignore
+    from megablocks._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
torch-ext/torch_binding.cpp CHANGED

@@ -34,22 +34,22 @@ torch::Tensor histogram_wrapper(torch::Tensor x, int64_t num_bins) {
 torch::Tensor indices_wrapper(torch::Tensor padded_bins,
                               int64_t block_size,
                               int64_t output_block_rows,
-                              int64_t output_block_columns
-
+                              int64_t output_block_columns,
+                              torch::Tensor out) {
   megablocks::indices(padded_bins, block_size, output_block_rows, output_block_columns, out);
   return out;
 }
 
 
 
-//
-//
-//
-//
-
-
-
-
+// Forward pass: replicate values from x according to bin sizes
+// void replicate_forward(torch::Tensor x,
+//                        torch::Tensor bins,
+//                        torch::Tensor out);
+torch::Tensor replicate_forward_wrapper(torch::Tensor x, torch::Tensor bins, torch::Tensor out) {
+  megablocks::replicate_forward(x, bins, out);
+  return out;
+}
 
 // // Backward pass: reduce gradients back to bins using segmented reduction
 // void replicate_backward(torch::Tensor grad,
@@ -90,11 +90,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("histogram(Tensor x, int num_bins) -> Tensor");
   ops.impl("histogram", torch::kCUDA, &histogram_wrapper);
 
-  ops.def("indices(Tensor padded_bins, int block_size, int output_block_rows, int output_block_columns) -> Tensor");
+  ops.def("indices(Tensor padded_bins, int block_size, int output_block_rows, int output_block_columns, Tensor(a!) out) -> Tensor(a!)");
   ops.impl("indices", torch::kCUDA, &indices_wrapper);
 
-
-
+  ops.def("replicate_forward(Tensor x, Tensor bins, Tensor(a!) out) -> Tensor(a!)");
+  ops.impl("replicate_forward", torch::kCUDA, &replicate_forward_wrapper);
 
   ops.def("replicate_backward(Tensor grad, Tensor bins, Tensor(a!) out) -> Tensor(a!)");
   ops.impl("replicate_backward", torch::kCUDA, &replicate_backward_wrapper);
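For orientation, the schemas registered here are what the Python wrappers reach through `from megablocks._ops import ops` in the `torch-ext/megablocks/ops/*.py` diffs above (and what `replicate_test.py` calls directly as `backend.replicate_backward`). A hedged sketch of a direct call — the exact generated-module layout depends on the kernel build tooling; the schema string is the one defined in this file:

```python
import torch
from megablocks._ops import ops  # generated binding, as imported in ops/*.py

x = torch.tensor([0, 1, 1, 2, 2, 2], dtype=torch.int16).cuda()
hist = ops.histogram(x, 3)  # schema: "histogram(Tensor x, int num_bins) -> Tensor"
```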