|
import logging |
|
|
|
import torch |
|
from torch.testing._internal import common_utils |
|
|
|
# Raise the "torch" logger threshold to WARNING so lower-severity records
# (DEBUG/INFO) from that logger hierarchy are not emitted during the tests.
logging.getLogger("torch").setLevel(logging.WARNING)
|
|
|
from apex.transformer import parallel_state |
|
from apex.transformer import tensor_parallel |
|
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase |
|
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase |
|
|
|
# Raise the "apex" logger threshold to WARNING so lower-severity records
# (DEBUG/INFO) from that logger hierarchy are not emitted during the tests.
logging.getLogger("apex").setLevel(logging.WARNING)
|
|
|
|
|
class TransformerRandomTestBase:
    """Backend-agnostic tests for ``tensor_parallel.random``.

    This class is a mixin: it reads ``self.world_size`` and calls the
    ``assert*`` methods, both of which are expected to come from the
    distributed test base it is combined with.  Every test is repeated for
    each tensor-model-parallel size that evenly divides ``self.world_size``.
    """

    def test_set_cuda_rng_state(self):
        """Check that ``_set_cuda_rng_state`` rewinds the CUDA generator.

        After seeding, five ``torch.randn`` draws are made and the last one is
        recorded.  The generator state captured before those draws is then
        restored and the draws are repeated: the final tensor must equal the
        recorded one, and the captured state tensor must stay unmodified.
        """
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            # Only sizes that evenly divide the world size are exercised.
            if self.world_size % tensor_model_parallel_world_size:
                continue
            msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            try:
                size, seed = 123, 1234
                torch.cuda.manual_seed(seed)
                # Explicitly shaped float32 CUDA buffer; avoids the legacy
                # ``torch.cuda.FloatTensor`` constructor.
                tensor = torch.empty(size, dtype=torch.float32, device="cuda")

                # Snapshot the generator state, plus a copy used to detect
                # accidental in-place modification of the snapshot.
                rng_state = torch.cuda.get_rng_state()
                rng_state_clone = rng_state.clone()

                for _ in range(5):
                    torch.randn(size, out=tensor)
                result_1 = tensor.clone()

                # Drawing numbers must not touch the snapshot ...
                self.assertEqual(rng_state.sub(rng_state_clone).max(), 0, msg=msg)
                # ... but must advance the live generator state.
                self.assertGreater(
                    torch.cuda.get_rng_state().sub(rng_state_clone).max(), 0,
                    msg=msg,
                )

                new_rng_state = torch.cuda.get_rng_state()
                self.assertGreater(new_rng_state.sub(rng_state).max(), 0, msg=msg)

                # Restore, draw, restore again, draw again: the second pass
                # must reproduce the recorded result exactly.
                tensor_parallel.random._set_cuda_rng_state(rng_state)
                for _ in range(5):
                    torch.randn(size, out=tensor)
                tensor_parallel.random._set_cuda_rng_state(rng_state)
                for _ in range(5):
                    torch.randn(size, out=tensor)
                result_2 = tensor.clone()

                self.assertEqual(result_2, result_1, msg=msg)

                # Restoring must not have modified the snapshot either.
                self.assertEqual(rng_state.sub(rng_state_clone).max(), 0, msg=msg)
            finally:
                # Tear down even when an assertion fails so that model-parallel
                # state does not leak out of this iteration.
                parallel_state.destroy_model_parallel()

    def test_cuda_rng_tracker(self):
        """Check that the CUDA RNG tracker keeps a separate random stream.

        Reference tensors are produced by seeding the default CUDA generator
        with ``seed_1`` and with ``seed_2``.  The default generator is then
        re-seeded with ``seed_1`` and a tracker state named ``"test"`` is added
        with ``seed_2``.  Draws made outside ``fork("test")`` must match the
        ``seed_1`` references and draws made inside it must match the
        ``seed_2`` references, with the two streams interleaved.
        """
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            # Only sizes that evenly divide the world size are exercised.
            if self.world_size % tensor_model_parallel_world_size:
                continue
            msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            try:
                seed_1, seed_2, size = 1234, 4321, [12, 21]
                # ``size`` is a list here: the legacy typed constructor would
                # treat it as data (a 2-element tensor), so allocate the
                # buffer with the requested shape explicitly.
                tensor = torch.empty(size, dtype=torch.float32, device="cuda")

                # Reference draws from the default generator seeded with seed_1.
                torch.cuda.manual_seed(seed_1)
                torch.randn(size, out=tensor)
                target_11 = tensor.clone()
                torch.randn(size, out=tensor)
                target_12 = tensor.clone()

                # Reference draws from the default generator seeded with seed_2.
                torch.cuda.manual_seed(seed_2)
                torch.randn(size, out=tensor)
                target_21 = tensor.clone()
                torch.randn(size, out=tensor)
                target_22 = tensor.clone()

                # Default stream uses seed_1; the tracked "test" stream is
                # added with seed_2.
                torch.cuda.manual_seed(seed_1)
                tensor_parallel.random.get_cuda_rng_tracker().add("test", seed_2)

                torch.randn(size, out=tensor)
                result_11 = tensor.clone()

                with tensor_parallel.random.get_cuda_rng_tracker().fork("test"):
                    torch.randn(size, out=tensor)
                    result_21 = tensor.clone()

                torch.randn(size, out=tensor)
                result_12 = tensor.clone()

                with tensor_parallel.random.get_cuda_rng_tracker().fork("test"):
                    torch.randn(size, out=tensor)
                    result_22 = tensor.clone()

                self.assertEqual(target_11, result_11, msg=msg)
                self.assertEqual(target_12, result_12, msg=msg)
                self.assertEqual(target_21, result_21, msg=msg)
                self.assertEqual(target_22, result_22, msg=msg)
                self.assertNotEqual(result_11, result_21, msg=msg)
                self.assertNotEqual(result_21, result_22, msg=msg)
            finally:
                # Always clear the tracker and model-parallel state so the
                # next iteration can register the "test" state again.
                tensor_parallel.random.get_cuda_rng_tracker().reset()
                parallel_state.destroy_model_parallel()
|
|
|
|
|
class NcclTransformerRandomTest(TransformerRandomTestBase, NcclDistributedTestBase):
    """Run the ``TransformerRandomTestBase`` tests on ``NcclDistributedTestBase``."""
|
class UccTransformerRandomTest(TransformerRandomTestBase, UccDistributedTestBase):
    """Run the ``TransformerRandomTestBase`` tests on ``UccDistributedTestBase``."""
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: hand control to the test runner from
    # ``torch.testing._internal.common_utils``.
    common_utils.run_tests()
|
|