from typing import Optional

import torch
from torch.utils.data import Sampler, ConcatDataset
class RandomConcatSampler(Sampler):
    """Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples are drawn from each subset
    in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset is done with replacement.
    However, sampling without replacement across epochs is impossible unless a stateful sampler lives through the
    entire training phase.
    In the current implementation, the randomness of sampling is ensured whether or not the sampler is recreated
    across epochs and whether or not `torch.manual_seed()` is called.
    Args:
        data_source (ConcatDataset): the concatenated dataset to sample from.
        n_samples_per_subset (int): number of samples to draw from each subset per epoch.
        subset_replacement (bool): sample within each subset with replacement.
        shuffle (bool): shuffle the randomly sampled indices across all sub-datasets.
        repeat (int): repeatedly use the sampled indices multiple times for training.
            [arXiv:1902.05509, arXiv:1901.09335]
        seed (int, optional): seed for the random generator.
    NOTE: Don't re-initialize the sampler between epochs (this will lead to repeated samples).
    NOTE: This sampler behaves differently from DistributedSampler:
        it assumes the dataset is split across ranks instead of replicated.
    TODO: Add a `set_epoch()` method to enable sampling without replacement across epochs.
        ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
    A minimal usage sketch follows the class definition.
    """
    def __init__(
        self,
        data_source: ConcatDataset,
        n_samples_per_subset: int,
        subset_replacement: bool = True,
        shuffle: bool = True,
        repeat: int = 1,
        seed: Optional[int] = None,
    ):
        if not isinstance(data_source, ConcatDataset):
            raise TypeError("data_source should be torch.utils.data.ConcatDataset")
        assert repeat >= 1, "repeat must be a positive integer"
        self.data_source = data_source
        self.n_subset = len(self.data_source.datasets)
        self.n_samples_per_subset = n_samples_per_subset
        self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
        self.subset_replacement = subset_replacement
        self.repeat = repeat
        self.shuffle = shuffle
        # torch.manual_seed() seeds the default generator and returns it, so the
        # sampler deliberately shares the global RNG state; fall back to the default
        # generator when no seed is given (torch.manual_seed(None) would raise).
        self.generator = torch.manual_seed(seed) if seed is not None else torch.default_generator
    def __len__(self):
        return self.n_samples
    def __iter__(self):
        indices = []
        # sample independently from each sub-dataset
        for d_idx in range(self.n_subset):
            low = 0 if d_idx == 0 else self.data_source.cumulative_sizes[d_idx - 1]
            high = self.data_source.cumulative_sizes[d_idx]
            if self.subset_replacement:
                rand_tensor = torch.randint(
                    low,
                    high,
                    (self.n_samples_per_subset,),
                    generator=self.generator,
                    dtype=torch.int64,
                )
            else:  # sample without replacement
                len_subset = len(self.data_source.datasets[d_idx])
                rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
                if len_subset >= self.n_samples_per_subset:
                    rand_tensor = rand_tensor[: self.n_samples_per_subset]
                else:  # pad the remainder by sampling with replacement
                    rand_tensor_replacement = torch.randint(
                        low,
                        high,
                        (self.n_samples_per_subset - len_subset,),
                        generator=self.generator,
                        dtype=torch.int64,
                    )
                    rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
            indices.append(rand_tensor)
        indices = torch.cat(indices)
        if self.shuffle:  # shuffle the sampled indices (drawn from multiple subsets)
            rand_tensor = torch.randperm(len(indices), generator=self.generator)
            indices = indices[rand_tensor]
        # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling)
        if self.repeat > 1:
            repeat_indices = [indices.clone() for _ in range(self.repeat - 1)]
            if self.shuffle:
                _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
                repeat_indices = map(_choice, repeat_indices)
            indices = torch.cat([indices, *repeat_indices], 0)
        assert indices.shape[0] == self.n_samples
        return iter(indices.tolist())
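
if __name__ == "__main__":
    # Minimal usage sketch, not part of the sampler itself: the two toy
    # TensorDatasets, the batch size, and the seed below are arbitrary
    # assumptions for illustration. It shows that one pass over the loader
    # yields n_subset * n_samples_per_subset * repeat indices, and that the
    # same sampler instance should be kept across epochs (re-creating it with
    # the same seed would replay identical indices).
    from torch.utils.data import DataLoader, TensorDataset

    subset_a = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))
    subset_b = TensorDataset(torch.arange(40, dtype=torch.float32).unsqueeze(1))
    concat_dataset = ConcatDataset([subset_a, subset_b])

    sampler = RandomConcatSampler(
        concat_dataset,
        n_samples_per_subset=8,
        subset_replacement=True,
        shuffle=True,
        repeat=1,
        seed=66,
    )
    loader = DataLoader(concat_dataset, batch_size=4, sampler=sampler)

    assert len(sampler) == 2 * 8 * 1  # n_subset * n_samples_per_subset * repeat
    for epoch in range(2):  # reuse the same sampler instance across epochs
        for (batch,) in loader:
            pass  # a real training step would consume `batch` here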