# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). # # -------------------------------------------------------- # Random sampling under a constraint # -------------------------------------------------------- import numpy as np import torch from typing import Callable, Iterable, Optional from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset, Sampler, BatchSampler import random def custom_collate_fn(batch): """ Custom collate function to handle variable batch sizes Args: batch: A list where each element could be either: - A single tuple (idx, num_images, ...) - A list of tuples [(idx1, num_images1, ...), (idx2, num_images2, ...)] """ # If batch contains lists (variable batch size case) breakpoint() if isinstance(batch[0], list): # Flatten the batch flattened = [] for item in batch: flattened.extend(item) batch = flattened # Now batch is a list of tuples, process normally return torch.utils.data.default_collate(batch) class BatchedRandomSampler: """Random sampling under a constraint: each sample in the batch has the same feature, which is chosen randomly from a known pool of 'features' for each batch. For instance, the 'feature' could be the image aspect-ratio. The index returned is a tuple (sample_idx, feat_idx). This sampler ensures that each series of `batch_size` indices has the same `feat_idx`. """ def __init__( self, dataset, batch_size, num_context_views, min_patch_num=20, max_patch_num=32, world_size=1, rank=0, drop_last=True ): self.batch_size = batch_size self.num_context_views = num_context_views self.len_dataset = N = len(dataset) self.total_size = round_by(N, batch_size * world_size) if drop_last else N self.min_patch_num = min_patch_num self.max_patch_num = max_patch_num assert ( world_size == 1 or drop_last ), "must drop the last batch in distributed mode" # distributed sampler self.world_size = world_size self.rank = rank self.epoch = None def __len__(self): return self.total_size // self.world_size def set_epoch(self, epoch): self.epoch = epoch def __iter__(self): # prepare RNG if self.epoch is None: assert ( self.world_size == 1 and self.rank == 0 ), "use set_epoch() if distributed mode is used" seed = int(torch.empty((), dtype=torch.int64).random_().item()) else: seed = self.epoch + 777 rng = np.random.default_rng(seed=seed) # random indices (will restart from 0 if not drop_last) sample_idxs = np.arange(self.total_size) rng.shuffle(sample_idxs) # random feat_idxs (same across each batch) n_batches = (self.total_size + self.batch_size - 1) // self.batch_size num_imgs = rng.integers(low=2, high=self.num_context_views, size=n_batches) # num_imgs = (np.ones(n_batches) * self.num_context_views).astype(np.int64) # same number of context views for each batch num_imgs = np.broadcast_to(num_imgs[:, None], (n_batches, self.batch_size)) num_imgs = num_imgs.ravel()[: self.total_size] # put them together idxs = np.c_[sample_idxs, num_imgs] # shape = (total_size, 2) # Distributed sampler: we select a subset of batches # make sure the slice for each node is aligned with batch_size size_per_proc = self.batch_size * ( (self.total_size + self.world_size * self.batch_size - 1) // (self.world_size * self.batch_size) ) idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc] yield from (tuple(idx) for idx in idxs) class DynamicBatchSampler(Sampler): """ A custom batch sampler that dynamically adjusts batch size, aspect ratio, and image number for each sample. Batches within a sample share the same aspect ratio and image number. """ def __init__(self, sampler, image_num_range, h_range, epoch=0, seed=42, max_img_per_gpu=48): """ Initializes the dynamic batch sampler. Args: sampler: Instance of DynamicDistributedSampler. aspect_ratio_range: List containing [min_aspect_ratio, max_aspect_ratio]. image_num_range: List containing [min_images, max_images] per sample. epoch: Current epoch number. seed: Random seed for reproducibility. max_img_per_gpu: Maximum number of images to fit in GPU memory. """ self.sampler = sampler self.image_num_range = image_num_range self.h_range = h_range self.rng = random.Random() # Uniformly sample from the range of possible image numbers # For any image number, the weight is 1.0 (uniform sampling). You can set any different weights here. self.image_num_weights = {num_images: float(num_images**2) for num_images in range(image_num_range[0], image_num_range[1]+1)} # Possible image numbers, e.g., [2, 3, 4, ..., 24] self.possible_nums = np.array([n for n in self.image_num_weights.keys() if self.image_num_range[0] <= n <= self.image_num_range[1]]) # Normalize weights for sampling weights = [self.image_num_weights[n] for n in self.possible_nums] self.normalized_weights = np.array(weights) / sum(weights) # Maximum image number per GPU self.max_img_per_gpu = max_img_per_gpu # Set the epoch for the sampler self.set_epoch(epoch + seed) def set_epoch(self, epoch): """ Sets the epoch for this sampler, affecting the random sequence. Args: epoch: The epoch number. """ self.sampler.set_epoch(epoch) self.epoch = epoch self.rng.seed(epoch * 100) def __iter__(self): """ Yields batches of samples with synchronized dynamic parameters. Returns: Iterator yielding batches of indices with associated parameters. """ sampler_iterator = iter(self.sampler) while True: try: # Sample random image number and aspect ratio random_image_num = int(np.random.choice(self.possible_nums, p=self.normalized_weights)) random_ps_h = np.random.randint(low=(self.h_range[0] // 14), high=(self.h_range[1] // 14)+1) # Update sampler parameters self.sampler.update_parameters( image_num=random_image_num, ps_h=random_ps_h ) # Calculate batch size based on max images per GPU and current image number batch_size = self.max_img_per_gpu / random_image_num batch_size = np.floor(batch_size).astype(int) batch_size = max(1, batch_size) # Ensure batch size is at least 1 # Collect samples for the current batch current_batch = [] for _ in range(batch_size): try: item = next(sampler_iterator) # item is (idx, aspect_ratio, image_num) current_batch.append(item) except StopIteration: break # No more samples if not current_batch: break # No more data to yield yield current_batch except StopIteration: break # End of sampler's iterator def __len__(self): # Return a large dummy length return 1000000 class DynamicDistributedSampler(DistributedSampler): """ Extends PyTorch's DistributedSampler to include dynamic aspect_ratio and image_num parameters, which can be passed into the dataset's __getitem__ method. """ def __init__( self, dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None, shuffle: bool = False, seed: int = 0, drop_last: bool = False, ): super().__init__( dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last ) self.image_num = None self.ps_h = None def __iter__(self): """ Yields a sequence of (index, image_num, aspect_ratio). Relies on the parent class's logic for shuffling/distributing the indices across replicas, then attaches extra parameters. """ indices_iter = super().__iter__() for idx in indices_iter: yield (idx, self.image_num, self.ps_h, ) def update_parameters(self, image_num, ps_h): """ Updates dynamic parameters for each new epoch or iteration. Args: aspect_ratio: The aspect ratio to set. image_num: The number of images to set. """ self.image_num = image_num self.ps_h = ps_h class MixedBatchSampler(BatchSampler): """Sample one batch from a selected dataset with given probability. Compatible with datasets at different resolution """ def __init__( self, src_dataset_ls, batch_size, num_context_views, world_size=1, rank=0, prob=None, sampler=None, generator=None ): self.base_sampler = None self.batch_size = batch_size self.num_context_views = num_context_views self.world_size = world_size self.rank = rank self.drop_last = True self.generator = generator self.src_dataset_ls = src_dataset_ls self.n_dataset = len(self.src_dataset_ls) # Dataset length self.dataset_length = [len(ds) for ds in self.src_dataset_ls] self.cum_dataset_length = [ sum(self.dataset_length[:i]) for i in range(self.n_dataset) ] # cumulative dataset length # BatchSamplers for each source dataset self.src_batch_samplers = [] for ds in self.src_dataset_ls: sampler = DynamicDistributedSampler(ds, num_replicas=self.world_size, rank=self.rank, seed=42, shuffle=True) sampler.set_epoch(0) if hasattr(ds, "epoch"): ds.epoch = 0 if hasattr(ds, "set_epoch"): ds.set_epoch(0) batch_sampler = DynamicBatchSampler( sampler, [2, ds.cfg.view_sampler.num_context_views], ds.cfg.input_image_shape, seed=42, max_img_per_gpu=ds.cfg.view_sampler.max_img_per_gpu ) self.src_batch_samplers.append(batch_sampler) # self.src_batch_samplers = [ # BatchedRandomSampler( # ds, # num_context_views=ds.cfg.view_sampler.num_context_views, # world_size=self.world_size, # rank=self.rank, # batch_size=self.batch_size, # drop_last=self.drop_last, # ) # for ds in self.src_dataset_ls # ] # set epoch here print("Setting epoch for all underlying BatchedRandomSamplers") # for sampler in self.src_batch_samplers: # sampler.set_epoch(0) self.raw_batches = [ list(bs) for bs in self.src_batch_samplers ] # index in original dataset self.n_batches = [len(b) for b in self.raw_batches] self.n_total_batch = sum(self.n_batches) # print("Total batch num is ", self.n_total_batch) # sampling probability if prob is None: # if not given, decide by dataset length self.prob = torch.tensor(self.n_batches) / self.n_total_batch else: self.prob = torch.as_tensor(prob) def __iter__(self): """Yields batches of indices in the format of (sample_idx, feat_idx) tuples, where indices correspond to ConcatDataset of src_dataset_ls """ for _ in range(self.n_total_batch): idx_ds = torch.multinomial( self.prob, 1, replacement=True, generator=self.generator ).item() if 0 == len(self.raw_batches[idx_ds]): self.raw_batches[idx_ds] = list(self.src_batch_samplers[idx_ds]) # get a batch from list - this is already in (sample_idx, feat_idx) format batch_raw = self.raw_batches[idx_ds].pop() # shift only the sample_idx by cumulative dataset length, keep feat_idx unchanged shift = self.cum_dataset_length[idx_ds] processed_batch = [] for item in batch_raw: # item[0] is the sample index, item[1] is the number of images processed_item = (item[0] + shift, item[1], item[2]) processed_batch.append(processed_item) yield processed_batch def set_epoch(self, epoch): """Set epoch for all underlying BatchedRandomSamplers""" for sampler in self.src_batch_samplers: sampler.set_epoch(epoch) # Reset raw_batches after setting new epoch self.raw_batches = [list(bs) for bs in self.src_batch_samplers] def __len__(self): return self.n_total_batch def round_by(total, multiple, up=False): if up: total = total + multiple - 1 return (total // multiple) * multiple