Spaces:
Runtime error
Runtime error
import itertools | |
import torch | |
from torch.utils.data.sampler import Sampler | |
from mmgpt.train.distributed import world_info_from_env | |
class InfiniteSampler(Sampler): | |
def __init__(self, dataset: int, shuffle: bool = True, seed: int = 0): | |
self._size = len(dataset) | |
self._shuffle = shuffle | |
self._seed = int(seed) | |
_, rank, world_size = world_info_from_env() | |
self._rank = rank | |
self._world_size = world_size | |
def __iter__(self): | |
start = self._rank | |
yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) | |
def _infinite_indices(self): | |
g = torch.Generator() | |
g.manual_seed(self._seed) | |
while True: | |
if self._shuffle: | |
yield from torch.randperm(self._size, generator=g).tolist() | |
else: | |
yield from torch.arange(self._size).tolist() | |