# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from megablocks import ops

# Each entry is (sl, hs, ne); see `test_topology` for how each value is used.
TOPOLOGY_TESTS = (
    (1024, 1536, 2),
    (1024, 1536, 4),
    (1024, 1536, 8),
    (1024, 1536, 16),
    (1024, 1536, 32),
    (1024, 1536, 64),
    (1024, 1536, 128),
    (1024, 1536, 256),
    (1024, 1536, 512),
    (16384, 768, 2),
    (16384, 768, 4),
    (16384, 768, 8),
    (16384, 768, 16),
    (16384, 768, 32),
    (16384, 768, 64),
    (16384, 768, 128),
    (16384, 768, 256),
    (16384, 768, 512),
    (16384, 768, 1024),
    (8, 14336, 8),
)


@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne'), TOPOLOGY_TESTS)
def test_topology(sl: int, hs: int, ne: int):
    """Compare ``ops.topology`` against a pure Python/NumPy reference.

    Args:
        sl: Number of expert assignments to draw (one per token).
        hs: Column extent; must be a multiple of the block size (128).
        ne: Number of experts; assignments are drawn from ``[0, ne)``.
    """
    # Create the data and indices.
    blocking = 128
    assert hs % blocking == 0

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl,)).cuda().int()
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, blocking)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)

    # Dimensions for the output indices.
    output_block_rows = int(padded_bins[-1]) // blocking
    output_block_columns = hs // blocking

    def topology(
        padded_bins: torch.Tensor,
        blocking: int,
        rows: int,
        columns: int,
    ) -> torch.Tensor:
        """Reference implementation producing the expected index tensor.

        Walks the bins in order. Every block row that ends before
        ``padded_bins[i] // blocking`` and is not yet filled receives the
        ``columns`` consecutive values ``i * columns ... i * columns +
        columns - 1``. The result is a flat tensor of ``rows * columns``
        entries, moved to the GPU and cast to int16.
        """
        padded_bins = padded_bins.cpu().numpy()
        out = np.zeros([rows * columns])
        start = 0
        for i in range(padded_bins.shape[0]):
            # Bin boundaries are in element units; convert to block rows.
            end = padded_bins[i] // blocking
            while start < end:
                for j in range(columns):
                    out[start * columns + j] = j + i * columns
                start += 1
        return torch.from_numpy(out).cuda().short()

    out = ops.topology(
        padded_bins,
        blocking,
        output_block_rows,
        output_block_columns,
    )
    expected_out = topology(
        padded_bins,
        blocking,
        output_block_rows,
        output_block_columns,
    )
    # torch.eq broadcasts, so pin the shape first to avoid a false pass on
    # a wrongly-shaped (but broadcast-compatible) output.
    assert out.shape == expected_out.shape
    assert torch.all(torch.eq(out, expected_out))