# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Iterator, List, Optional, Sized, Union
import torch
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
class FixedBatchMultiSourceSampler(Sampler):
r"""Multi-Source Infinite Sampler.
According to the sampling ratio, sample data from different
datasets to form batches.
Args:
repeat (tuple): repeat factor
dataset (Sized): The dataset.
batch_size (int): Size of mini-batch.
shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
seed (int, optional): Random seed. If None, set a random seed.
Defaults to None.
"""
def __init__(self,
repeat,
dataset: Sized,
batch_size: int,
shuffle: bool = True,
seed: Optional[int] = None) -> None:
        assert hasattr(dataset, 'cumulative_sizes'), \
            f'The dataset must be a ConcatDataset, but got {dataset}'
assert isinstance(batch_size, int) and batch_size > 0, \
'batch_size must be a positive integer value, ' \
f'but got batch_size={batch_size}'
assert len(repeat) == len(dataset.cumulative_sizes), \
'The length of repeat must be equal to ' \
f'the number of datasets, but got repeat={repeat}'
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.repeat = repeat
self.cumulative_sizes = [0] + dataset.cumulative_sizes
self.batch_size = batch_size
self.seed = sync_random_seed() if seed is None else seed
self.shuffle = shuffle
self.source2inds = {
source: self._indices_of_rank(len(ds))
for source, ds in enumerate(dataset.datasets)
}
def _infinite_indices(self, sample_size: int) -> Iterator[int]:
"""Infinitely yield a sequence of indices."""
g = torch.Generator()
g.manual_seed(self.seed)
while True:
if self.shuffle:
yield from torch.randperm(sample_size, generator=g).tolist()
else:
yield from torch.arange(sample_size).tolist()
def _indices_of_rank(self, sample_size: int) -> Iterator[int]:
"""Slice the infinite indices by rank."""
yield from itertools.islice(
self._infinite_indices(sample_size), self.rank, None,
self.world_size)
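    # Illustrative note: ``_indices_of_rank`` gives each process a disjoint,
    # infinite slice of the shared shuffled stream. With ``world_size=2`` and
    # a stream such as [3, 1, 4, 0, ...], rank 0 would receive [3, 4, ...] and
    # rank 1 would receive [1, 0, ...]; the shared seed keeps the streams
    # aligned across ranks.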
def __len__(self) -> int:
return len(self.dataset)
    def set_epoch(self, epoch: int) -> None:
        """Not supported in epoch-based runners."""
        pass
def __iter__(self) -> Iterator[int]:
while True:
for source, repeat in enumerate(self.repeat):
for _ in range(repeat):
batch_buffer_per_source = []
while len(batch_buffer_per_source) < self.batch_size:
idx = next(self.source2inds[source])
idx += self.cumulative_sizes[source]
batch_buffer_per_source.append(idx)
yield from batch_buffer_per_source
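
# Usage sketch (illustrative; ``dataset_a``/``dataset_b`` stand in for the two
# datasets wrapped in the ConcatDataset and are not defined in this module):
#
#     sampler = FixedBatchMultiSourceSampler(
#         repeat=(2, 1),
#         dataset=ConcatDataset([dataset_a, dataset_b]),
#         batch_size=4,
#         seed=0)
#
# Iterating over ``sampler`` yields 2 * 4 indices from ``dataset_a``, then
# 1 * 4 indices from ``dataset_b`` (offset by ``len(dataset_a)``), repeating
# forever, so a downstream batch sampler of size 4 sees two ``dataset_a``
# batches for every ``dataset_b`` batch.
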
class MultiSourceSampler(Sampler):
    """Infinite multi-source sampler with per-source batch sizes.

    For each source ``s``, ``repeats[s]`` consecutive batches of
    ``batch_sizes[s]`` indices are drawn from the corresponding dataset of
    the ConcatDataset before moving on to the next source.

    Args:
        repeats (list): Number of consecutive batches to draw from each
            dataset before switching to the next one.
        dataset (Sized): The dataset, which must be a ConcatDataset.
        batch_sizes (list[int]): Per-source mini-batch sizes.
        shuffle (bool): Whether to shuffle the dataset or not.
            Defaults to True.
        seed (int, optional): Random seed. If None, set a random seed.
            Defaults to None.
    """
def __init__(self,
repeats,
dataset: Sized,
                 batch_sizes: List[int],
shuffle: bool = True,
seed: Optional[int] = None) -> None:
        assert hasattr(dataset, 'cumulative_sizes'), \
            f'The dataset must be a ConcatDataset, but got {dataset}'
        assert isinstance(batch_sizes, list), \
            f'batch_sizes must be a list, but got batch_sizes={batch_sizes}'
assert len(batch_sizes) == len(dataset.cumulative_sizes), \
'The length of batch_sizes must be equal to ' \
f'the number of datasets, but got batch_sizes={batch_sizes}'
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.cumulative_sizes = [0] + dataset.cumulative_sizes
self.batch_sizes = batch_sizes
self.seed = sync_random_seed() if seed is None else seed
self.shuffle = shuffle
self.source2inds = {
source: self._indices_of_rank(len(ds))
for source, ds in enumerate(dataset.datasets)
}
        self.repeats = repeats
        assert len(self.repeats) == len(self.batch_sizes), \
            'The length of repeats must be equal to the length of ' \
            f'batch_sizes, but got repeats={repeats}'
def _infinite_indices(self, sample_size: int) -> Iterator[int]:
"""Infinitely yield a sequence of indices."""
g = torch.Generator()
g.manual_seed(self.seed)
while True:
if self.shuffle:
yield from torch.randperm(sample_size, generator=g).tolist()
else:
yield from torch.arange(sample_size).tolist()
def _indices_of_rank(self, sample_size: int) -> Iterator[int]:
"""Slice the infinite indices by rank."""
yield from itertools.islice(
self._infinite_indices(sample_size), self.rank, None,
self.world_size)
def __len__(self) -> int:
return len(self.dataset)
    def set_epoch(self, epoch: int) -> None:
        """Not supported in epoch-based runners."""
        pass
def __iter__(self) -> Iterator[int]:
while True:
            for source, (batch_size, repeat) in enumerate(
                    zip(self.batch_sizes, self.repeats)):
for _ in range(repeat):
batch_buffer_per_source = []
while len(batch_buffer_per_source) < batch_size:
idx = next(self.source2inds[source])
idx += self.cumulative_sizes[source]
batch_buffer_per_source.append(idx)
yield from batch_buffer_per_source
    @property
    def batch_size(self) -> int:
        """Repeat-weighted average batch size over all sources."""
        batch_size_sum = sum(
            batch_size * repeat
            for batch_size, repeat in zip(self.batch_sizes, self.repeats))
        return batch_size_sum // sum(self.repeats)
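
# Usage sketch (illustrative): with ``batch_sizes=[4, 2]`` and
# ``repeats=[2, 1]``, iterating over the sampler yields 4 indices from
# source 0, another 4 from source 0, then 2 from source 1, and so on forever.
# The ``batch_size`` property reports the repeat-weighted average, here
# (4 * 2 + 2 * 1) // (2 + 1) = 3, which downstream code may use for length
# estimates.
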
class MultiSourceBatchSampler(Sampler[List[int]]):
    """Batch sampler that regroups a multi-source index stream into batches.

    The wrapped sampler yields a flat stream of indices that is already
    ordered per source; this batch sampler slices that stream into lists of
    ``batch_sizes[s]`` indices, ``repeats[s]`` batches at a time per source.
    """
def __init__(
self,
sampler: Union[FixedBatchMultiSourceSampler, MultiSourceSampler],
        batch_sizes: List[int],
        repeats: List[int],
**kwargs
) -> None:
self.sampler = sampler
self.batch_sizes = batch_sizes
self.repeats = repeats
    def __iter__(self) -> Iterator[List[int]]:
# Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
sampler_iter = iter(self.sampler)
while True:
            for batch_size, repeat in zip(self.batch_sizes, self.repeats):
for _ in range(repeat):
batch = [*itertools.islice(sampler_iter, batch_size)]
yield batch
    @property
    def batch_size(self) -> int:
        """Repeat-weighted average batch size over all sources."""
        batch_size_sum = sum(
            batch_size * repeat
            for batch_size, repeat in zip(self.batch_sizes, self.repeats))
        return batch_size_sum // sum(self.repeats)
def __len__(self) -> int:
return len(self.sampler) // self.batch_size
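

if __name__ == '__main__':
    # Minimal smoke-test sketch (illustrative only): build a tiny
    # ConcatDataset of 10 and 6 samples and print a few batches to show how
    # the samplers interleave the two sources. The TensorDatasets below are
    # placeholders, not part of any real training config.
    from torch.utils.data import ConcatDataset, TensorDataset

    concat = ConcatDataset([
        TensorDataset(torch.arange(10)),
        TensorDataset(torch.arange(6)),
    ])
    # Two batches of 4 indices from source 0, then one batch of 2 indices
    # from source 1, repeated indefinitely.
    sampler = MultiSourceSampler(
        repeats=[2, 1], dataset=concat, batch_sizes=[4, 2], seed=0)
    batch_sampler = MultiSourceBatchSampler(
        sampler, batch_sizes=[4, 2], repeats=[2, 1])
    for batch in itertools.islice(iter(batch_sampler), 6):
        print(batch)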