vista2d / scripts /cell_distributed_weighted_sampler.py
project-monai's picture
Upload vista2d version 0.3.1
fd4ffa6 verified
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# based on Pytorch DistributedSampler and WeightedRandomSampler combined
import math
from typing import Iterator, Optional, Sequence, TypeVar
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, Sampler
__all__ = ["DistributedWeightedSampler"]
T_co = TypeVar("T_co", covariant=True)
class DistributedWeightedSampler(Sampler[T_co]):
def __init__(
self,
dataset: Dataset,
weights: Sequence[float],
num_samples: int,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
if not isinstance(num_samples, int) or isinstance(num_samples, bool) or num_samples <= 0:
raise ValueError(f"num_samples should be a positive integer value, but got num_samples={num_samples}")
weights_tensor = torch.as_tensor(weights, dtype=torch.float)
if len(weights_tensor.shape) != 1:
raise ValueError(
"weights should be a 1d sequence but given " f"weights have shape {tuple(weights_tensor.shape)}"
)
self.weights = weights_tensor
self.num_samples = num_samples
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
if rank >= num_replicas or rank < 0:
raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.drop_last = drop_last
self.shuffle = shuffle
if self.shuffle:
self.num_samples = int(math.ceil(self.num_samples / self.num_replicas))
else:
# this is not used, as we always shuffle, the only reason to use this class
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
def __iter__(self) -> Iterator[T_co]:
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.multinomial(input=self.weights, num_samples=self.total_size, replacement=True, generator=g).tolist() # type: ignore[arg-type]
else:
# this is not used, as we always shuffle, the only reason to use this class
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self) -> int:
return self.num_samples
def set_epoch(self, epoch: int) -> None:
self.epoch = epoch