""" |
|
Distributed Training Utilities |
|
|
|
This file contains utility functions for distributed training with PyTorch. |
|
It provides tools for setting up distributed environments, efficient file handling |
|
across processes, model unwrapping, and synchronization mechanisms to coordinate |
|
execution across multiple GPUs and nodes. |
|
""" |
|
|
|
import os |
|
import io |
|
from contextlib import contextmanager |
|
import torch |
|
import torch.distributed as dist |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
|
|
|
def setup_dist(rank, local_rank, world_size, master_addr, master_port):
    """
    Set up the distributed training environment.

    Args:
        rank (int): Global rank of the current process
        local_rank (int): Local rank of the current process on this node
        world_size (int): Total number of processes in the distributed training job
        master_addr (str): IP address of the master node
        master_port (str): Port on the master node for communication
    """
    # Expose the rendezvous information through the standard environment variables.
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(local_rank)

    # Bind this process to its GPU before initializing the NCCL process group.
    torch.cuda.set_device(local_rank)

    dist.init_process_group('nccl', rank=rank, world_size=world_size)

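
# Example (illustrative sketch, not part of the original utilities): one way
# setup_dist might be driven on a single node with torch.multiprocessing, where
# the global rank equals the local rank. The `worker` function, address, and
# port below are hypothetical placeholders.
#
#     import torch.multiprocessing as mp
#
#     def worker(local_rank, world_size):
#         setup_dist(rank=local_rank, local_rank=local_rank, world_size=world_size,
#                    master_addr='127.0.0.1', master_port='29500')
#         # ... build the model, wrap it in DDP(model, device_ids=[local_rank]), train ...
#
#     if __name__ == '__main__':
#         n_gpus = torch.cuda.device_count()
#         mp.spawn(worker, args=(n_gpus,), nprocs=n_gpus)
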
def read_file_dist(path):
    """
    Read a binary file in a distributed setting.
    The file is read only once, by the rank 0 process, and then broadcast to the
    other processes, which reduces I/O overhead in distributed training.

    Args:
        path (str): Path to the file to be read

    Returns:
        data (io.BytesIO): The binary data read from the file.
    """
    if dist.is_initialized() and dist.get_world_size() > 1:
        # Rank 0 reads the file and stages the bytes in a CUDA tensor.
        size = torch.LongTensor(1).cuda()
        if dist.get_rank() == 0:
            with open(path, 'rb') as f:
                data = f.read()
            data = torch.ByteTensor(
                torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
            ).cuda()
            size[0] = data.shape[0]
        # Broadcast the payload size so the other ranks can allocate a receive buffer.
        dist.broadcast(size, src=0)
        if dist.get_rank() != 0:
            data = torch.ByteTensor(size[0].item()).cuda()
        # Broadcast the file contents to all ranks.
        dist.broadcast(data, src=0)
        # Convert the received bytes back into an in-memory file object.
        data = data.cpu().numpy().tobytes()
        data = io.BytesIO(data)
        return data
    else:
        # Single-process case: just read the file directly.
        with open(path, 'rb') as f:
            data = f.read()
        data = io.BytesIO(data)
        return data

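
# Example (illustrative sketch): because read_file_dist returns an io.BytesIO,
# it can be passed to any consumer that accepts a file-like object, e.g. loading
# a checkpoint once per job instead of once per rank. The path is a hypothetical
# placeholder.
#
#     state_dict = torch.load(read_file_dist('/path/to/checkpoint.pt'), map_location='cpu')
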
def unwrap_dist(model):
    """
    Unwrap a model from its distributed training wrapper.

    Args:
        model: A potentially wrapped PyTorch model

    Returns:
        The underlying model without the DistributedDataParallel wrapper
    """
    if isinstance(model, DDP):
        return model.module
    return model

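
# Example (illustrative sketch): a typical use of unwrap_dist is saving or
# inspecting the underlying module regardless of whether it is wrapped in
# DistributedDataParallel. The `model` variable and output path are hypothetical.
#
#     torch.save(unwrap_dist(model).state_dict(), 'model.pt')
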
@contextmanager
def master_first():
    """
    A context manager that ensures the master process (rank 0) executes first.
    All other processes wait for the master to finish before proceeding.

    Usage:
        with master_first():
            # Code that should run on the master first, then on the other ranks
    """
    if not dist.is_initialized():
        # Not running distributed: nothing to coordinate.
        yield
    else:
        if dist.get_rank() == 0:
            # Master runs the body, then releases the other ranks.
            yield
            dist.barrier()
        else:
            # Other ranks wait for the master before running the body.
            dist.barrier()
            yield

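
# Example (illustrative sketch): a common pattern is letting rank 0 build a cache
# (e.g. download or preprocess a dataset) before the other ranks touch it.
# `load_or_build_dataset_cache` is a hypothetical helper that creates the cache
# if it is missing and otherwise reuses it.
#
#     with master_first():
#         # Rank 0 enters first and builds the cache; by the time the other ranks
#         # enter, the cached files already exist and are simply re-used.
#         dataset = load_or_build_dataset_cache()
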
@contextmanager
def local_master_first():
    """
    A context manager that ensures the local master process (the first process on
    each node) executes first. The other processes wait at a global barrier until
    every local master has finished before proceeding. The local master is taken to
    be any process whose global rank is a multiple of the per-node GPU count, which
    assumes one process per GPU.

    Usage:
        with local_master_first():
            # Code that should run on the local master first, then on the other ranks
    """
    if not dist.is_initialized():
        # Not running distributed: nothing to coordinate.
        yield
    else:
        if dist.get_rank() % torch.cuda.device_count() == 0:
            # Local master runs the body, then releases the other ranks.
            yield
            dist.barrier()
        else:
            # Other ranks wait for the local masters before running the body.
            dist.barrier()
            yield
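
# Example (illustrative sketch): per-node preparation such as copying or unpacking
# data onto node-local storage, done once per node rather than once per process.
# `ensure_local_copy` is a hypothetical helper that extracts the archive only if
# the local copy does not exist yet.
#
#     with local_master_first():
#         # The first process on each node creates the local copy; the remaining
#         # processes only enter once it is already in place.
#         data_dir = ensure_local_copy('/shared/dataset.tar', '/local_scratch/dataset')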