File size: 4,918 Bytes
491eded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
Distributed Training Utilities

This file contains utility functions for distributed training with PyTorch.
It provides tools for setting up distributed environments, efficient file handling
across processes, model unwrapping, and synchronization mechanisms to coordinate 
execution across multiple GPUs and nodes.
"""

import os
import io
from contextlib import contextmanager
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_dist(rank, local_rank, world_size, master_addr, master_port, backend='nccl'):
    """
    Set up the distributed training environment.

    Exports the rendezvous settings as environment variables (``MASTER_ADDR``,
    ``MASTER_PORT``, ``WORLD_SIZE``, ``RANK``, ``LOCAL_RANK``), binds this
    process to CUDA device ``local_rank`` and initializes the default process
    group.

    Args:
        rank (int): Global rank of the current process
        local_rank (int): Local rank of the current process on this node; also
            used as the CUDA device index for this process
        world_size (int): Total number of processes in the distributed training
        master_addr (str): IP address (or hostname) of the master node
        master_port (str | int): Port on the master node for communication
        backend (str): Backend passed to ``dist.init_process_group``.
            Defaults to ``'nccl'``.
    """
    # os.environ only accepts strings; str() lets callers pass e.g. the port as an int.
    os.environ['MASTER_ADDR'] = str(master_addr)
    os.environ['MASTER_PORT'] = str(master_port)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(local_rank)
    # Set the device for the current process
    torch.cuda.set_device(local_rank)
    # Initialize the process group for distributed communication
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    

def read_file_dist(path):
    """
    Read a binary file once and share its contents with every process.

    When a process group with more than one rank is initialized, only rank 0
    opens the file; its length and then its bytes are broadcast to the other
    ranks, which reduces I/O overhead in distributed training. Otherwise the
    file is simply read directly.

    Args:
        path (str): Path to the file to be read

    Returns:
        io.BytesIO: In-memory buffer holding the file contents.

    Raises:
        OSError: On rank 0, the error raised while opening/reading ``path``.
            On the other ranks, an ``OSError`` reporting that rank 0 could
            not read the file, so no rank is left waiting in a broadcast.
    """
    if not (dist.is_initialized() and dist.get_world_size() > 1):
        # Non-distributed or single-process case: just read directly.
        with open(path, 'rb') as f:
            return io.BytesIO(f.read())

    # NCCL can only broadcast CUDA tensors; other backends (e.g. gloo) use CPU tensors.
    device = torch.device('cuda') if dist.get_backend() == 'nccl' else torch.device('cpu')
    is_master = dist.get_rank() == 0

    data = b''
    read_error = None
    size = torch.zeros(1, dtype=torch.long, device=device)
    if is_master:
        try:
            with open(path, 'rb') as f:
                data = f.read()
        except OSError as exc:
            read_error = exc
        # A size of -1 tells the other ranks that the read failed, so they
        # raise instead of waiting for a data broadcast that never comes.
        size[0] = -1 if read_error is not None else len(data)
    # Broadcast file size (or the failure sentinel) to all processes.
    dist.broadcast(size, src=0)
    num_bytes = int(size.item())

    if num_bytes < 0:
        if read_error is not None:
            raise read_error
        raise OSError(f'rank 0 failed to read file {path!r}')
    if num_bytes == 0:
        # Empty file: every rank skips the second collective consistently.
        return io.BytesIO(b'')

    if is_master:
        # bytearray gives frombuffer a writable buffer (bytes would trigger a warning).
        payload = torch.frombuffer(bytearray(data), dtype=torch.uint8).to(device)
    else:
        payload = torch.empty(num_bytes, dtype=torch.uint8, device=device)
    # Broadcast actual file data to all processes.
    dist.broadcast(payload, src=0)
    if is_master:
        # Rank 0 already holds the bytes; skip the device-to-host copy.
        return io.BytesIO(data)
    return io.BytesIO(payload.cpu().numpy().tobytes())
    

def unwrap_dist(model):
    """
    Strip a ``DistributedDataParallel`` wrapper, if there is one.

    Args:
        model: A model that may or may not be wrapped in
            ``DistributedDataParallel``.

    Returns:
        ``model.module`` when ``model`` is a ``DistributedDataParallel``
        instance; otherwise ``model`` itself, untouched.
    """
    return model.module if isinstance(model, DDP) else model


@contextmanager
def master_first():
    """
    A context manager that ensures master process (rank 0) executes first.

    Rank 0 runs the body and then enters ``dist.barrier()``; every other rank
    waits at that barrier and runs the body only after rank 0 has finished.
    Without an initialized process group the body simply runs once.

    The barrier on rank 0 is issued from a ``finally`` clause. If the body
    raises on rank 0, the other ranks are still released instead of blocking
    in ``dist.barrier()`` (until the process-group timeout, if any). The
    exception then propagates on rank 0.

    Usage:
        with master_first():
            ...  # code that should execute in master first, then others
    """
    if not dist.is_initialized():
        # If not in distributed mode, just execute normally
        yield
        return
    if dist.get_rank() == 0:
        try:
            # Master process executes the code
            yield
        finally:
            # Signal completion to other processes, even if the body raised
            dist.barrier()
    else:
        # Other processes wait for master to finish, then execute the code
        dist.barrier()
        yield
            

@contextmanager
def local_master_first():
    """
    A context manager that ensures local master process (local rank 0 on each
    node) executes first; the remaining processes run the body afterwards.

    The local rank is taken from the ``LOCAL_RANK`` environment variable when
    it is set. If it is not set, the local rank is estimated as
    ``rank % torch.cuda.device_count()``, which assumes one process per
    visible GPU. ``dist.barrier()`` is collective over the default process
    group. The non-master ranks therefore wait until every node's local
    master has finished the body.

    The barrier on a local master is issued from a ``finally`` clause. If the
    body raises there, the other ranks are still released instead of
    blocking in ``dist.barrier()``, and the exception propagates.

    Usage:
        with local_master_first():
            ...  # code that should execute in local master first, then others
    """
    if not dist.is_initialized():
        # If not in distributed mode, just execute normally
        yield
        return

    local_rank = os.environ.get('LOCAL_RANK')
    if local_rank is not None:
        is_local_master = int(local_rank) == 0
    else:
        # Fallback heuristic: only correct with exactly one process per visible GPU.
        is_local_master = dist.get_rank() % torch.cuda.device_count() == 0

    if is_local_master:
        try:
            # Local master process executes the code
            yield
        finally:
            # Signal completion to other processes, even if the body raised
            dist.barrier()
    else:
        # Other processes wait for the local masters to finish, then execute the code
        dist.barrier()
        yield