# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from megablocks import ops
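
# Each test case is (m, n, dtype, max_val): `m` rows of `n` random integers
# drawn from [0, max_val), histogrammed into `max_val` bins. The grid sweeps
# batch sizes {1, 2, 8}, row lengths {32, 1024, 16384}, the int16/int32/int64
# dtypes, and bin counts {128, 1024}.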
_HISTOGRAM_TESTS = (
    (1, 32, torch.int16, 128),
    (1, 1024, torch.int16, 128),
    (1, 16384, torch.int16, 128),
    (1, 32, torch.int32, 128),
    (1, 1024, torch.int32, 128),
    (1, 16384, torch.int32, 128),
    (1, 32, torch.int64, 128),
    (1, 1024, torch.int64, 128),
    (1, 16384, torch.int64, 128),
    (1, 32, torch.int16, 1024),
    (1, 1024, torch.int16, 1024),
    (1, 16384, torch.int16, 1024),
    (1, 32, torch.int32, 1024),
    (1, 1024, torch.int32, 1024),
    (1, 16384, torch.int32, 1024),
    (1, 32, torch.int64, 1024),
    (1, 1024, torch.int64, 1024),
    (1, 16384, torch.int64, 1024),
    (2, 32, torch.int16, 128),
    (2, 1024, torch.int16, 128),
    (2, 16384, torch.int16, 128),
    (2, 32, torch.int32, 128),
    (2, 1024, torch.int32, 128),
    (2, 16384, torch.int32, 128),
    (2, 32, torch.int64, 128),
    (2, 1024, torch.int64, 128),
    (2, 16384, torch.int64, 128),
    (2, 32, torch.int16, 1024),
    (2, 1024, torch.int16, 1024),
    (2, 16384, torch.int16, 1024),
    (2, 32, torch.int32, 1024),
    (2, 1024, torch.int32, 1024),
    (2, 16384, torch.int32, 1024),
    (2, 32, torch.int64, 1024),
    (2, 1024, torch.int64, 1024),
    (2, 16384, torch.int64, 1024),
    (8, 32, torch.int16, 128),
    (8, 1024, torch.int16, 128),
    (8, 16384, torch.int16, 128),
    (8, 32, torch.int32, 128),
    (8, 1024, torch.int32, 128),
    (8, 16384, torch.int32, 128),
    (8, 32, torch.int64, 128),
    (8, 1024, torch.int64, 128),
    (8, 16384, torch.int64, 128),
    (8, 32, torch.int16, 1024),
    (8, 1024, torch.int16, 1024),
    (8, 16384, torch.int16, 1024),
    (8, 32, torch.int32, 1024),
    (8, 1024, torch.int32, 1024),
    (8, 16384, torch.int32, 1024),
    (8, 32, torch.int64, 1024),
    (8, 1024, torch.int64, 1024),
    (8, 16384, torch.int64, 1024),
)


# Override the seed_all fixture in autouse.py because _histc_cuda does not
# have a deterministic implementation: with deterministic algorithms enforced,
# the torch.histc reference below would raise a RuntimeError on CUDA.
@pytest.fixture()
def seed_all():
    torch.use_deterministic_algorithms(False)
    return


@pytest.mark.gpu
@pytest.mark.parametrize(('m', 'n', 'dtype', 'max_val'), _HISTOGRAM_TESTS)
def test_histogram(m: int, n: int, dtype: torch.dtype, max_val: int):
    x = torch.randint(0, max_val, (m, n)).cuda().to(dtype)

    out = ops.histogram(x, max_val)

    # Reference: a per-row histogram via torch.histc, with `max_val` bins
    # covering the integer values [0, max_val - 1].
    expected_out = torch.stack([torch.histc(y, max_val, 0, max_val - 1) for y in torch.split(x, 1)])
    assert torch.all(torch.eq(out, expected_out))
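
# For illustration (a hand-worked sketch, not part of the suite): rows are
# binned independently, so with max_val = 4,
#   x = [[0, 1, 1, 3]]  ->  ops.histogram(x, 4) == [[1, 2, 0, 1]]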