Spaces:
Runtime error
Runtime error
File size: 934 Bytes
03561be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import itertools
from collections.abc import Sized
from typing import Iterator, Optional

import torch
from torch.utils.data.sampler import Sampler

from mmgpt.train.distributed import world_info_from_env
class InfiniteSampler(Sampler):
    """Endless, rank-sharded sampler of dataset indices.

    Produces an infinite stream of indices in ``range(len(dataset))``. Each
    pass over the dataset is either a fresh random permutation
    (``shuffle=True``) drawn from a ``torch.Generator`` seeded with ``seed``,
    or the identity order. The stream is then split across workers: the
    worker with rank ``r`` receives every ``world_size``-th element starting
    at position ``r``. Because every worker regenerates the stream locally,
    the shards are disjoint only when all workers use the same ``seed`` and
    dataset length.

    Args:
        dataset: Any object supporting ``len()``; only its length is used.
        shuffle: Draw a new random permutation on every pass when True;
            otherwise yield indices in order.
        seed: Seed for the permutation generator.
        rank: This worker's rank. When omitted, taken from
            ``world_info_from_env()``.
        world_size: Total number of workers. When omitted, taken from
            ``world_info_from_env()``.

    Raises:
        ValueError: If ``dataset`` is empty (iteration would otherwise loop
            forever without yielding anything).
    """

    def __init__(
        self,
        dataset: Sized,
        shuffle: bool = True,
        seed: int = 0,
        *,
        rank: Optional[int] = None,
        world_size: Optional[int] = None,
    ) -> None:
        self._size = len(dataset)
        if self._size == 0:
            # With no indices per pass, _infinite_indices' `while True` never
            # yields and the consumer hangs; fail loudly instead.
            raise ValueError("InfiniteSampler requires a non-empty dataset")
        self._shuffle = shuffle
        self._seed = int(seed)
        if rank is None or world_size is None:
            # First element of the returned triple is unused here
            # (presumably the local rank — confirm against world_info_from_env).
            _, env_rank, env_world_size = world_info_from_env()
            rank = env_rank if rank is None else rank
            world_size = env_world_size if world_size is None else world_size
        self._rank = rank
        self._world_size = world_size

    def __iter__(self) -> Iterator[int]:
        """Yield this worker's shard of the infinite index stream."""
        # Take stream positions rank, rank + world_size, rank + 2*world_size, ...
        yield from itertools.islice(
            self._infinite_indices(), self._rank, None, self._world_size
        )

    def _infinite_indices(self) -> Iterator[int]:
        """Yield dataset indices forever, one full pass at a time."""
        # A freshly seeded generator per call means every new iterator replays
        # the same deterministic sequence from the start.
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if self._shuffle:
                yield from torch.randperm(self._size, generator=g).tolist()
            else:
                yield from range(self._size)
|