import random
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
def worker_init_fn(worker_id):
    # Re-seed numpy and Python's `random` in every DataLoader worker so that
    # random augmentations differ across workers instead of being duplicated.
    # https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
class DistInfiniteBatchSampler(Sampler):
    """Infinite batch sampler that shards a (re-)shuffled index list across distributed ranks."""
    def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=1, filling=False, shuffle=True):
        # The global batch must split evenly across ranks.
        assert glb_batch_size % world_size == 0
        self.world_size, self.rank = world_size, rank
        self.dataset_len = dataset_len
        self.glb_batch_size = glb_batch_size
        self.batch_size = glb_batch_size // world_size  # per-rank batch size
        self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size  # ceil division
        self.filling = filling
        self.shuffle = shuffle
        self.epoch = 0
        self.seed = seed
        self.indices = self.gener_indices()
    def gener_indices(self):
        # global_max_p % world_size == 0 because glb_batch_size % world_size == 0
        global_max_p = self.iters_per_ep * self.glb_batch_size
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.seed)
            global_indices = torch.randperm(self.dataset_len, generator=g)
        else:
            global_indices = torch.arange(self.dataset_len)
        filling = global_max_p - global_indices.shape[0]
        if filling > 0 and self.filling:
            # Pad with the leading indices so every rank receives a full share.
            global_indices = torch.cat((global_indices, global_indices[:filling]))
        global_indices = tuple(global_indices.numpy().tolist())

        # Split the global index list into world_size contiguous shards.
        seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int)
        local_indices = global_indices[seps[self.rank]:seps[self.rank + 1]]
        self.max_p = len(local_indices)
        return local_indices
    def __iter__(self):
        self.epoch = 0
        while True:  # never raises StopIteration: the sampler is infinite
            self.epoch += 1
            p, q = 0, 0
            while p < self.max_p:
                q = p + self.batch_size
                yield self.indices[p:q]
                p = q
            if self.shuffle:
                # Re-shuffle with a new epoch-dependent seed for the next pass.
                self.indices = self.gener_indices()
    def __len__(self):
        return self.iters_per_ep
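
# Illustrative usage sketch (not part of the original file): the sampler yields
# index lists, so it plugs into DataLoader via `batch_sampler`, together with
# `worker_init_fn` above. The dataset, worker count, and global batch size of
# 256 below are assumptions made only for this example.
def _example_dataloader(dataset, world_size, rank):
    from torch.utils.data import DataLoader
    sampler = DistInfiniteBatchSampler(
        world_size=world_size, rank=rank,
        dataset_len=len(dataset), glb_batch_size=256,
    )
    # With `batch_sampler`, DataLoader's own batch_size/shuffle/drop_last must
    # stay at their defaults; the sampler already emits complete batches.
    return DataLoader(
        dataset, batch_sampler=sampler,
        num_workers=4, worker_init_fn=worker_init_fn,
    )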
if __name__ == '__main__':
    # Sanity check: with dataset_len == glb_batch_size == 5024 and 16 ranks,
    # each rank should own 5024 / 16 = 314 indices.
    W = 16
    for rk in range(W):
        ind = DistInfiniteBatchSampler(W, rk, 5024, 5024).gener_indices()
        print(rk, len(ind))