# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Distributed training utilities.

"""
import logging
import pickle

import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Subset
from torch.nn.parallel.distributed import DistributedDataParallel

from dora import distrib as dora_distrib

logger = logging.getLogger(__name__)
rank = 0
world_size = 1


def init():
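    """Initialize the distributed backend through Dora if needed, and set the
    module-level `rank` and `world_size` for this process."""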
    global rank, world_size
    if not torch.distributed.is_initialized():
        dora_distrib.init()
    rank = dora_distrib.rank()
    world_size = dora_distrib.world_size()


def average(metrics, count=1.):
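    """Average `metrics` (a dict or a sequence of floats) over all workers,
    weighting each worker's contribution by `count` (e.g. the number of
    examples it processed). Returns the metrics unchanged on a single worker."""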
    if isinstance(metrics, dict):
        keys, values = zip(*sorted(metrics.items()))
        values = average(values, count)
        return dict(zip(keys, values))
    if world_size == 1:
        return metrics
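    # Append a 1 so that, once scaled by `count` and summed across workers,
    # the last entry carries the total weight used for normalization below.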
    tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
    tensor *= count
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()


def wrap(model):
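    """Wrap the model with DistributedDataParallel when running on more than
    one worker; return the model unchanged otherwise."""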
    if world_size == 1:
        return model
    else:
        return DistributedDataParallel(
            model,
            # find_unused_parameters=True,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())


def barrier():
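    """Wait for all workers to reach this point; no-op on a single worker."""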
    if world_size > 1:
        torch.distributed.barrier()


def share(obj=None, src=0):
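    """Share a picklable object from worker `src` with all other workers;
    every worker returns the same object. No-op on a single worker."""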
    if world_size == 1:
        return obj
    size = torch.empty(1, device='cuda', dtype=torch.long)
    if rank == src:
        dump = pickle.dumps(obj)
        size[0] = len(dump)
    torch.distributed.broadcast(size, src=src)
    # size variable is now set to the length of pickled obj in all processes

    if rank == src:
        buffer = torch.from_numpy(np.frombuffer(dump, dtype=np.uint8).copy()).cuda()
    else:
        buffer = torch.empty(size[0].item(), device='cuda', dtype=torch.uint8)
    torch.distributed.broadcast(buffer, src=src)
    # buffer variable is now set to pickled obj in all processes

    if rank != src:
        obj = pickle.loads(buffer.cpu().numpy().tobytes())
    logger.debug(f"Shared object of size {len(buffer)}")
    return obj


def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """

    Create a dataloader properly in case of distributed training.

    If a gradient is going to be computed you must set `shuffle=True`.

    """
    if world_size == 1:
        return klass(dataset, *args, shuffle=shuffle, **kwargs)

    if shuffle:
        # Training means we will compute a backward pass, so we use DistributedSampler.
        sampler = DistributedSampler(dataset)
        # We drop the `shuffle` argument here, as DistributedSampler already shuffles.
        return klass(dataset, *args, **kwargs, sampler=sampler)
    else:
        # We shard manually, as DistributedSampler would otherwise replicate some examples.
        dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
        return klass(dataset, *args, shuffle=shuffle, **kwargs)