# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# based on PyTorch DistributedSampler and WeightedRandomSampler combined

import math
from typing import Iterator, Optional, Sequence, TypeVar

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, Sampler

__all__ = ["DistributedWeightedSampler"]

T_co = TypeVar("T_co", covariant=True)


class DistributedWeightedSampler(Sampler[T_co]):
    """Sampler combining distributed rank splitting with weighted random sampling.

    Each epoch, ``num_samples`` indices in total are drawn from ``weights`` with
    replacement (via ``torch.multinomial``) and split evenly across ranks, so every
    replica receives the same number of samples. Call :meth:`set_epoch` at the start
    of each epoch to obtain a different deterministic shuffle.
    """

    def __init__(
        self,
        dataset: Dataset,
        weights: Sequence[float],
        num_samples: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={num_samples}")
        weights_tensor = torch.as_tensor(weights, dtype=torch.float)
        if len(weights_tensor.shape) != 1:
            raise ValueError(
                f"weights should be a 1d sequence but given weights have shape {tuple(weights_tensor.shape)}"
            )
        self.weights = weights_tensor
        self.num_samples = num_samples
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.seed = seed
        if self.shuffle:
            # split the requested number of weighted draws evenly across replicas
            self.num_samples = int(math.ceil(self.num_samples / self.num_replicas))
        else:
            # this branch is rarely used, as weighted shuffling is the main reason to use this class.
            # If the dataset length is evenly divisible by the number of replicas, then there
            # is no need to drop any data, since the dataset will be split equally.
            if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
                # Split to the nearest length that is evenly divisible.
                # This ensures each rank receives the same amount of data when
                # using this sampler.
                self.num_samples = math.ceil(
                    (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
                )
            else:
                self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.multinomial(
                input=self.weights, num_samples=self.total_size, replacement=True, generator=g
            ).tolist()
        else:
            # this branch is rarely used, as weighted shuffling is the main reason to use this class
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make the index list evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make the index list evenly divisible
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample: each rank takes every ``num_replicas``-th index starting at its own offset
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch
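

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API): builds
# a toy dataset whose minority class is up-weighted, then draws one epoch of
# indices for each of two simulated ranks. Passing ``num_replicas`` and
# ``rank`` explicitly avoids needing an initialized process group; in a real
# DDP job these arguments are normally omitted and read from
# ``torch.distributed``. The dataset, weighting scheme, and rank loop below
# are assumptions made for the demo.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    data = torch.arange(10).unsqueeze(1).float()
    labels = torch.tensor([0] * 8 + [1] * 2)  # class 1 is the minority class
    dataset = TensorDataset(data, labels)

    # weight each sample inversely to its class frequency, so the minority
    # class is drawn more often than its raw proportion would allow
    class_counts = torch.bincount(labels).float()
    sample_weights = (1.0 / class_counts)[labels].tolist()

    for rank in range(2):
        sampler = DistributedWeightedSampler(
            dataset, weights=sample_weights, num_samples=8, num_replicas=2, rank=rank
        )
        sampler.set_epoch(0)  # call once per epoch in a real training loop
        print(f"rank {rank}: {list(sampler)}")  # 4 indices per rank (8 total / 2 replicas)

    # In a real job the sampler is passed to a DataLoader, e.g.:
    #   loader = DataLoader(dataset, batch_size=2, sampler=sampler)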