# This file contains the changes to implement DDP training with the train.yaml config.
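
# Launch sketch (file and directory names are assumptions; adjust to the actual bundle layout):
# torchrun starts one process per GPU and sets LOCAL_RANK for each, which the device definition
# below relies on, while this file is merged over train.yaml, e.g.:
#   torchrun --standalone --nproc_per_node=<num_gpus> -m monai.bundle run \
#       --config_file "['configs/train.yaml','configs/multi_gpu_train.yaml']"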
device: "$torch.device('cuda:' + os.environ['LOCAL_RANK'])"  # assumes the GPU index matches the local rank

# wrap the network in a DistributedDataParallel instance, moving it to the chosen device for this process
network:
  _target_: torch.nn.parallel.DistributedDataParallel
  module: $@network_def.to(@device)
  device_ids: ['@device']
  find_unused_parameters: true
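
# shard the training data across processes; per-epoch shuffling is delegated to the sampler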
train_sampler:
  _target_: DistributedSampler
  dataset: '@train_dataset'
  even_divisible: true
  shuffle: true
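
# '#' addresses a key inside a definition from train.yaml: give the existing dataloader the
# distributed sampler and turn off its own shuffling, which the sampler now handles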
train_dataloader#sampler: '@train_sampler'
train_dataloader#shuffle: false
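
# validation is not shuffled, and even_divisible: false avoids padding ranks with duplicate
# samples, so every validation item is evaluated exactly once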
val_sampler:
  _target_: DistributedSampler
  dataset: '@val_dataset'
  even_divisible: false
  shuffle: false
val_dataloader#sampler: '@val_sampler'
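
# initialize expressions run once in every process before training starts: join the NCCL
# process group, bind the process to its GPU, then seed each rank identically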
initialize:
- $import torch.distributed as dist
- $dist.init_process_group(backend='nccl')
- $torch.cuda.set_device(@device)
- $monai.utils.set_determinism(seed=123)  # may want to choose a different seed or not do this here
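
# the trainer from train.yaml runs unchanged; it picks up the DDP-wrapped network and the
# distributed dataloaders through the overrides above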
run:
- '[email protected]()'

finalize:
- '$dist.is_initialized() and dist.destroy_process_group()'