# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from megablocks import ops

# Each test case is (m, n, dtype, max_val): an m x n batch of random values
# drawn from [0, max_val) and cast to the given integer dtype.
_HISTOGRAM_TESTS = (
    (1, 32, torch.int16, 128),
    (1, 1024, torch.int16, 128),
    (1, 16384, torch.int16, 128),
    (1, 32, torch.int32, 128),
    (1, 1024, torch.int32, 128),
    (1, 16384, torch.int32, 128),
    (1, 32, torch.int64, 128),
    (1, 1024, torch.int64, 128),
    (1, 16384, torch.int64, 128),
    (1, 32, torch.int16, 1024),
    (1, 1024, torch.int16, 1024),
    (1, 16384, torch.int16, 1024),
    (1, 32, torch.int32, 1024),
    (1, 1024, torch.int32, 1024),
    (1, 16384, torch.int32, 1024),
    (1, 32, torch.int64, 1024),
    (1, 1024, torch.int64, 1024),
    (1, 16384, torch.int64, 1024),
    (2, 32, torch.int16, 128),
    (2, 1024, torch.int16, 128),
    (2, 16384, torch.int16, 128),
    (2, 32, torch.int32, 128),
    (2, 1024, torch.int32, 128),
    (2, 16384, torch.int32, 128),
    (2, 32, torch.int64, 128),
    (2, 1024, torch.int64, 128),
    (2, 16384, torch.int64, 128),
    (2, 32, torch.int16, 1024),
    (2, 1024, torch.int16, 1024),
    (2, 16384, torch.int16, 1024),
    (2, 32, torch.int32, 1024),
    (2, 1024, torch.int32, 1024),
    (2, 16384, torch.int32, 1024),
    (2, 32, torch.int64, 1024),
    (2, 1024, torch.int64, 1024),
    (2, 16384, torch.int64, 1024),
    (8, 32, torch.int16, 128),
    (8, 1024, torch.int16, 128),
    (8, 16384, torch.int16, 128),
    (8, 32, torch.int32, 128),
    (8, 1024, torch.int32, 128),
    (8, 16384, torch.int32, 128),
    (8, 32, torch.int64, 128),
    (8, 1024, torch.int64, 128),
    (8, 16384, torch.int64, 128),
    (8, 32, torch.int16, 1024),
    (8, 1024, torch.int16, 1024),
    (8, 16384, torch.int16, 1024),
    (8, 32, torch.int32, 1024),
    (8, 1024, torch.int32, 1024),
    (8, 16384, torch.int32, 1024),
    (8, 32, torch.int64, 1024),
    (8, 1024, torch.int64, 1024),
    (8, 16384, torch.int64, 1024),
)
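
# The grid above is the full Cartesian product of its four fields. For
# reference, an equivalent construction (a sketch of the same set, up to
# ordering; not how this file was generated) would use itertools.product
# from the standard library:
#
#   _HISTOGRAM_TESTS = tuple(itertools.product(
#       (1, 2, 8),
#       (32, 1024, 16384),
#       (torch.int16, torch.int32, torch.int64),
#       (128, 1024),
#   ))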


# Override the seed_all fixture in autouse.py because _histc_cuda, which backs
# the torch.histc call used to build the expected output, does not have a
# deterministic implementation.
@pytest.fixture()
def seed_all():
    torch.use_deterministic_algorithms(False)


@pytest.mark.gpu
@pytest.mark.parametrize(('m', 'n', 'dtype', 'max_val'), _HISTOGRAM_TESTS)
def test_histogram(m: int, n: int, dtype: torch.dtype, max_val: int):
    x = torch.randint(0, max_val, (m, n)).cuda().to(dtype)

    out = ops.histogram(x, max_val)
    # torch.histc flattens its input, so build the expected per-row result by
    # histogramming each row separately and stacking.
    expected_out = torch.stack([torch.histc(y, max_val, 0, max_val - 1) for y in torch.split(x, 1)])
    assert torch.all(torch.eq(out, expected_out))
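

# A minimal pure-PyTorch sketch of the semantics checked above, assuming (from
# this test, not from the kernel source) that ops.histogram computes a per-row
# histogram with one bin per integer value in [0, max_val).
# `_reference_histogram` is a hypothetical helper, not part of the megablocks
# API.
def _reference_histogram(x: torch.Tensor, max_val: int) -> torch.Tensor:
    # torch.bincount only accepts 1-D input, so apply it row by row; cast each
    # row to int64 first since bincount does not support every integer dtype
    # on every device.
    return torch.stack([torch.bincount(row.long(), minlength=max_val) for row in x])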