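"""Tests for tensor model parallel broadcast_data with NCCL and UCC backends."""
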
import logging

import torch.testing
from torch.testing._internal import common_utils

logging.getLogger("torch").setLevel(logging.WARNING)

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import data as data_utils
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

logging.getLogger("apex").setLevel(logging.WARNING)


class BroadcastDataTestBase:
    def test_broadcast_data(self):
        # With more than one process, use half of the ranks for the tensor
        # model parallel group; with a single process the group size is 1.
        tensor_model_parallel_world_size: int = self.world_size // (
            1 + int(self.world_size > 1)
        )
        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=tensor_model_parallel_world_size
        )

        # Shapes of the tensors that rank 0 of each tensor model parallel
        # group broadcasts to the other ranks.
        target_key_size = {
            "key1": [7, 11],
            "key2": [8, 2, 1],
            "key3": [13],
            "key4": [5, 1, 2],
            "key5": [5, 12],
        }
        keys = list(target_key_size)

        # Build the reference data; `data_t` keeps a copy of every tensor so
        # the result of the broadcast can be checked on each rank.
        data = {}
        data_t = {}
        with torch.no_grad():
            for key in target_key_size:
                data[key] = torch.randint(0, 1000, size=target_key_size[key])
                data_t[key] = data[key].clone()
            # Extra entry with a different dtype; it is not listed in `keys`
            # and therefore is not broadcast.
            data["key_x"] = torch.rand(5)
            data_t["key_x"] = data["key_x"].clone()
        # Only the first rank of each tensor model parallel group supplies data.
        if parallel_state.get_tensor_model_parallel_rank() != 0:
            data = None

        # All broadcast keys are expected to hold int64 tensors, and the
        # shapes distributed from rank 0 must match the ones defined above.
        data_utils._check_data_types(keys, data_t, torch.int64)
        key_size, _, _ = data_utils._build_key_size_numel_dictionaries(keys, data)

        for key in keys:
            self.assertEqual(target_key_size[key], key_size[key])

        # Every rank should receive tensors equal to the reference copies.
        broadcasted_data = data_utils.broadcast_data(keys, data, torch.int64)
        for key in keys:
            self.assertEqual(broadcasted_data[key], data_t[key].cuda())

        parallel_state.destroy_model_parallel()


class NcclBroadcastDataTest(BroadcastDataTestBase, NcclDistributedTestBase): pass
class UccBroadcastDataTest(BroadcastDataTestBase, UccDistributedTestBase): pass


if __name__ == "__main__":
    common_utils.run_tests()