# Provenance (hosting-page residue, kept as a comment so the module parses):
# uploaded by qitaoz in commit 4562a06 ("Upload 57 files", verified); 771 bytes.
import os
import socket
from contextlib import closing
import torch.distributed as dist
def get_open_port() -> int:
    """Return a TCP port number that the OS just reported as free.

    A throwaway IPv4 stream socket is bound to port 0 so the OS picks an
    unused port; that port number is read back and the socket is closed.

    Returns:
        The port number that was assigned to the temporary socket.

    Note:
        The socket is closed before this function returns, so the port is
        only *likely* to still be free afterwards -- another process may
        claim it before the caller binds to it.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # Port 0 asks the OS to choose any available port.
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # getsockname() returns (host, port) for AF_INET; keep the port only.
        return s.getsockname()[1]
# Distributed process group
def ddp_setup(rank: int, world_size: int, port="12345", backend: str = "nccl"):
    """Initialize the default ``torch.distributed`` process group.

    Exports ``MASTER_ADDR`` ("localhost") and ``MASTER_PORT`` to the process
    environment, prints the chosen port, and then calls
    ``dist.init_process_group``.

    Args:
        rank: Unique identifier of this process.
        world_size: Number of processes.
        port: Rendezvous port; converted with ``str`` before being exported
            as ``MASTER_PORT``, so an int or a string is accepted.
        backend: Backend name forwarded to ``init_process_group``. Defaults
            to ``"nccl"``, which was previously the hard-coded value.
    """
    master_port = str(port)
    os.environ["MASTER_ADDR"] = "localhost"
    print(f"MasterPort: {master_port}")
    os.environ["MASTER_PORT"] = master_port
    # initialize the process group
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
def cleanup():
    """Destroy the default process group if one is currently initialized.

    Does nothing when ``torch.distributed`` is unavailable or no process
    group has been initialized, so it is safe to call from a ``finally``
    block even if setup failed, and safe to call more than once.
    """
    # is_available() is checked first so is_initialized() is only called
    # when torch.distributed reports itself as available.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()