import os
import subprocess

import torch
import torch.distributed as dist


def setup_distributed(backend="nccl", port=None):
    """Initialize torch.distributed from the RANK / WORLD_SIZE environment variables.

    Lifted from https://github.com/BIGBALLON/distribuuuu/blob/master/distribuuuu/utils.py
    Originally licensed MIT, Copyright (c) 2020 Wei Li
    """
    num_gpus = torch.cuda.device_count()
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Pin this process to one local GPU, cycling through the available devices by rank.
    torch.cuda.set_device(rank % num_gpus)

    dist.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
    )
    return rank, world_size
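

# A minimal usage sketch, assuming the script is launched with a launcher that sets
# RANK and WORLD_SIZE (e.g. `torchrun --nproc_per_node=4 this_script.py`). The tiny
# Linear model below is only a placeholder to show the DDP wiring, not part of the
# original utility.
if __name__ == "__main__":
    from torch.nn.parallel import DistributedDataParallel as DDP

    rank, world_size = setup_distributed(backend="nccl")
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

    # Wrap a toy model so gradients are all-reduced across all world_size processes.
    model = torch.nn.Linear(8, 8).to(device)
    ddp_model = DDP(model, device_ids=[device.index])

    out = ddp_model(torch.randn(4, 8, device=device))
    out.sum().backward()

    if rank == 0:
        print(f"rank {rank}/{world_size}: backward pass with gradient all-reduce complete")

    dist.destroy_process_group()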