# AnySplat/src/dataset/data_sampler.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Random sampling under a constraint
# --------------------------------------------------------
import random
from typing import Optional

import numpy as np
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, DistributedSampler, IterableDataset, Sampler


def custom_collate_fn(batch):
"""
Custom collate function to handle variable batch sizes
Args:
batch: A list where each element could be either:
- A single tuple (idx, num_images, ...)
- A list of tuples [(idx1, num_images1, ...), (idx2, num_images2, ...)]
"""
# If batch contains lists (variable batch size case)
if isinstance(batch[0], list):
# Flatten the batch
flattened = []
for item in batch:
flattened.extend(item)
batch = flattened
# Now batch is a list of tuples, process normally
return torch.utils.data.default_collate(batch)
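

# Illustrative sketch (not part of the original module): shows how custom_collate_fn
# flattens a nested batch (a list of per-sampler lists) before delegating to torch's
# default_collate. The placeholder samples below are hypothetical.
def _example_custom_collate():
    def _sample(i):
        return i, torch.zeros(3)  # hypothetical (idx, features) tuple
    nested_batch = [[_sample(0), _sample(1)], [_sample(2), _sample(3)]]
    collated = custom_collate_fn(nested_batch)
    # collated[0] -> tensor([0, 1, 2, 3]); collated[1] -> tensor of shape (4, 3)
    return collated
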
class BatchedRandomSampler:
"""Random sampling under a constraint: each sample in the batch has the same feature,
which is chosen randomly from a known pool of 'features' for each batch.
For instance, the 'feature' could be the image aspect-ratio.
The index returned is a tuple (sample_idx, feat_idx).
This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
"""
def __init__(
self, dataset, batch_size, num_context_views, min_patch_num=20, max_patch_num=32, world_size=1, rank=0, drop_last=True
):
self.batch_size = batch_size
self.num_context_views = num_context_views
self.len_dataset = N = len(dataset)
self.total_size = round_by(N, batch_size * world_size) if drop_last else N
self.min_patch_num = min_patch_num
self.max_patch_num = max_patch_num
assert (
world_size == 1 or drop_last
), "must drop the last batch in distributed mode"
# distributed sampler
self.world_size = world_size
self.rank = rank
self.epoch = None
def __len__(self):
return self.total_size // self.world_size
def set_epoch(self, epoch):
self.epoch = epoch
def __iter__(self):
# prepare RNG
if self.epoch is None:
assert (
self.world_size == 1 and self.rank == 0
), "use set_epoch() if distributed mode is used"
seed = int(torch.empty((), dtype=torch.int64).random_().item())
else:
seed = self.epoch + 777
rng = np.random.default_rng(seed=seed)
        # random permutation of all sample indices
        sample_idxs = np.arange(self.total_size)
rng.shuffle(sample_idxs)
        # random number of context views, shared by every sample in a batch
        n_batches = (self.total_size + self.batch_size - 1) // self.batch_size
        num_imgs = rng.integers(low=2, high=self.num_context_views, size=n_batches)  # high is exclusive
# num_imgs = (np.ones(n_batches) * self.num_context_views).astype(np.int64) # same number of context views for each batch
num_imgs = np.broadcast_to(num_imgs[:, None], (n_batches, self.batch_size))
num_imgs = num_imgs.ravel()[: self.total_size]
# put them together
idxs = np.c_[sample_idxs, num_imgs] # shape = (total_size, 2)
# Distributed sampler: we select a subset of batches
# make sure the slice for each node is aligned with batch_size
size_per_proc = self.batch_size * (
(self.total_size + self.world_size * self.batch_size - 1)
// (self.world_size * self.batch_size)
)
idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc]
yield from (tuple(idx) for idx in idxs)
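

# Illustrative sketch (not part of the original module): iterating BatchedRandomSampler
# yields (sample_idx, num_context_views) tuples, with the same view count repeated
# across every run of `batch_size` consecutive indices. The toy dataset is hypothetical.
def _example_batched_random_sampler():
    dummy_dataset = list(range(10))  # any object with __len__ works here
    sampler = BatchedRandomSampler(dummy_dataset, batch_size=2, num_context_views=8)
    for sample_idx, num_views in sampler:
        print(sample_idx, num_views)
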
class DynamicBatchSampler(Sampler):
"""
A custom batch sampler that dynamically adjusts batch size, aspect ratio, and image number
for each sample. Batches within a sample share the same aspect ratio and image number.
"""
def __init__(self,
sampler,
image_num_range,
h_range,
epoch=0,
seed=42,
max_img_per_gpu=48):
"""
Initializes the dynamic batch sampler.
Args:
sampler: Instance of DynamicDistributedSampler.
aspect_ratio_range: List containing [min_aspect_ratio, max_aspect_ratio].
image_num_range: List containing [min_images, max_images] per sample.
epoch: Current epoch number.
seed: Random seed for reproducibility.
max_img_per_gpu: Maximum number of images to fit in GPU memory.
"""
self.sampler = sampler
self.image_num_range = image_num_range
self.h_range = h_range
self.rng = random.Random()
        # Weight each possible image number by num_images**2, biasing sampling toward
        # larger view counts; change these weights to alter the distribution.
        self.image_num_weights = {num_images: float(num_images**2) for num_images in range(image_num_range[0], image_num_range[1] + 1)}
# Possible image numbers, e.g., [2, 3, 4, ..., 24]
self.possible_nums = np.array([n for n in self.image_num_weights.keys()
if self.image_num_range[0] <= n <= self.image_num_range[1]])
# Normalize weights for sampling
weights = [self.image_num_weights[n] for n in self.possible_nums]
self.normalized_weights = np.array(weights) / sum(weights)
# Maximum image number per GPU
self.max_img_per_gpu = max_img_per_gpu
# Set the epoch for the sampler
self.set_epoch(epoch + seed)
def set_epoch(self, epoch):
"""
Sets the epoch for this sampler, affecting the random sequence.
Args:
epoch: The epoch number.
"""
self.sampler.set_epoch(epoch)
self.epoch = epoch
self.rng.seed(epoch * 100)
def __iter__(self):
"""
Yields batches of samples with synchronized dynamic parameters.
Returns:
Iterator yielding batches of indices with associated parameters.
"""
sampler_iterator = iter(self.sampler)
while True:
try:
                # Sample the per-batch image number and patch-grid height from the
                # epoch-seeded RNG so the schedule is reproducible
                random_image_num = int(self.rng.choices(list(self.possible_nums), weights=self.normalized_weights.tolist(), k=1)[0])
                random_ps_h = self.rng.randint(self.h_range[0] // 14, self.h_range[1] // 14)
# Update sampler parameters
self.sampler.update_parameters(
image_num=random_image_num,
ps_h=random_ps_h
)
                # Size the batch so that batch_size * image_num stays within max_img_per_gpu
                batch_size = max(1, int(self.max_img_per_gpu // random_image_num))
# Collect samples for the current batch
current_batch = []
for _ in range(batch_size):
try:
                        item = next(sampler_iterator)  # item is (idx, image_num, ps_h)
current_batch.append(item)
except StopIteration:
break # No more samples
if not current_batch:
break # No more data to yield
yield current_batch
except StopIteration:
break # End of sampler's iterator
    def __len__(self):
        # The true number of batches is not known in advance (batch size varies per
        # batch), so return a large dummy length for consumers that require __len__
        return 1000000
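

# Illustrative sketch (not part of the original module): a minimal stand-in sampler
# demonstrates how DynamicBatchSampler draws an image count per batch, pushes it to
# the sampler via update_parameters(), and sizes the batch so that
# batch_size * image_num <= max_img_per_gpu. The stub and its values are hypothetical.
class _StubDynamicSampler:
    def __init__(self, n):
        self.n = n
        self.image_num = None
        self.ps_h = None
    def set_epoch(self, epoch):
        pass
    def update_parameters(self, image_num, ps_h):
        self.image_num = image_num
        self.ps_h = ps_h
    def __iter__(self):
        # read image_num/ps_h lazily, mirroring DynamicDistributedSampler below
        return ((idx, self.image_num, self.ps_h) for idx in range(self.n))
def _example_dynamic_batch_sampler():
    batch_sampler = DynamicBatchSampler(_StubDynamicSampler(16),
                                        image_num_range=[2, 8],
                                        h_range=[252, 504],
                                        max_img_per_gpu=48)
    for batch in batch_sampler:
        print(batch)  # e.g. [(0, 5, 27), (1, 5, 27), ...] -- one image_num/ps_h per batch
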
class DynamicDistributedSampler(DistributedSampler):
"""
Extends PyTorch's DistributedSampler to include dynamic aspect_ratio and image_num
parameters, which can be passed into the dataset's __getitem__ method.
"""
def __init__(
self,
dataset,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = False,
seed: int = 0,
drop_last: bool = False,
):
super().__init__(
dataset,
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
seed=seed,
drop_last=drop_last
)
self.image_num = None
self.ps_h = None
def __iter__(self):
"""
Yields a sequence of (index, image_num, aspect_ratio).
Relies on the parent class's logic for shuffling/distributing
the indices across replicas, then attaches extra parameters.
"""
indices_iter = super().__iter__()
for idx in indices_iter:
yield (idx, self.image_num, self.ps_h, )
def update_parameters(self, image_num, ps_h):
"""
Updates dynamic parameters for each new epoch or iteration.
Args:
aspect_ratio: The aspect ratio to set.
image_num: The number of images to set.
"""
self.image_num = image_num
self.ps_h = ps_h
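

# Illustrative sketch (not part of the original module): with explicit num_replicas
# and rank no torch.distributed initialization is needed, and every yielded item
# carries the (image_num, ps_h) pair set via update_parameters(). Values are hypothetical.
def _example_dynamic_distributed_sampler():
    dummy_dataset = list(range(6))
    sampler = DynamicDistributedSampler(dummy_dataset, num_replicas=1, rank=0, shuffle=True)
    sampler.set_epoch(0)
    sampler.update_parameters(image_num=4, ps_h=18)
    return list(sampler)  # e.g. [(3, 4, 18), (0, 4, 18), (5, 4, 18), ...]
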
class MixedBatchSampler(BatchSampler):
"""Sample one batch from a selected dataset with given probability.
Compatible with datasets at different resolution
"""
def __init__(
self, src_dataset_ls, batch_size, num_context_views, world_size=1, rank=0, prob=None, sampler=None, generator=None
):
self.base_sampler = None
self.batch_size = batch_size
self.num_context_views = num_context_views
self.world_size = world_size
self.rank = rank
self.drop_last = True
self.generator = generator
self.src_dataset_ls = src_dataset_ls
self.n_dataset = len(self.src_dataset_ls)
# Dataset length
self.dataset_length = [len(ds) for ds in self.src_dataset_ls]
self.cum_dataset_length = [
sum(self.dataset_length[:i]) for i in range(self.n_dataset)
] # cumulative dataset length
# BatchSamplers for each source dataset
self.src_batch_samplers = []
for ds in self.src_dataset_ls:
sampler = DynamicDistributedSampler(ds, num_replicas=self.world_size, rank=self.rank, seed=42, shuffle=True)
sampler.set_epoch(0)
if hasattr(ds, "epoch"):
ds.epoch = 0
if hasattr(ds, "set_epoch"):
ds.set_epoch(0)
batch_sampler = DynamicBatchSampler(
sampler,
[2, ds.cfg.view_sampler.num_context_views],
ds.cfg.input_image_shape,
seed=42,
max_img_per_gpu=ds.cfg.view_sampler.max_img_per_gpu
)
self.src_batch_samplers.append(batch_sampler)
# self.src_batch_samplers = [
# BatchedRandomSampler(
# ds,
# num_context_views=ds.cfg.view_sampler.num_context_views,
# world_size=self.world_size,
# rank=self.rank,
# batch_size=self.batch_size,
# drop_last=self.drop_last,
# )
# for ds in self.src_dataset_ls
# ]
        # epoch 0 is already set on each DynamicDistributedSampler above, and each
        # DynamicBatchSampler seeds itself in its constructor
self.raw_batches = [
list(bs) for bs in self.src_batch_samplers
] # index in original dataset
self.n_batches = [len(b) for b in self.raw_batches]
self.n_total_batch = sum(self.n_batches)
# print("Total batch num is ", self.n_total_batch)
# sampling probability
if prob is None:
# if not given, decide by dataset length
self.prob = torch.tensor(self.n_batches) / self.n_total_batch
else:
self.prob = torch.as_tensor(prob)
def __iter__(self):
"""Yields batches of indices in the format of (sample_idx, feat_idx) tuples,
where indices correspond to ConcatDataset of src_dataset_ls
"""
for _ in range(self.n_total_batch):
idx_ds = torch.multinomial(
self.prob, 1, replacement=True, generator=self.generator
).item()
            if len(self.raw_batches[idx_ds]) == 0:
                self.raw_batches[idx_ds] = list(self.src_batch_samplers[idx_ds])
            # get a batch from the list - items are already (sample_idx, image_num, ps_h) tuples
batch_raw = self.raw_batches[idx_ds].pop()
            # shift only the sample_idx by the cumulative dataset length; keep image_num and ps_h unchanged
shift = self.cum_dataset_length[idx_ds]
processed_batch = []
for item in batch_raw:
                # item[0] is the sample index, item[1] the number of images, item[2] the patch-grid height
                processed_item = (item[0] + shift, item[1], item[2])
processed_batch.append(processed_item)
yield processed_batch
def set_epoch(self, epoch):
"""Set epoch for all underlying BatchedRandomSamplers"""
for sampler in self.src_batch_samplers:
sampler.set_epoch(epoch)
# Reset raw_batches after setting new epoch
self.raw_batches = [list(bs) for bs in self.src_batch_samplers]
def __len__(self):
return self.n_total_batch
def round_by(total, multiple, up=False):
    """Round `total` down to the nearest multiple of `multiple` (up if `up=True`)."""
if up:
total = total + multiple - 1
return (total // multiple) * multiple
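

# Illustrative sketch (not part of the original module): exercising MixedBatchSampler
# with mock datasets. The `cfg` namespaces below are assumptions made only for this
# demo; the real AnySplat dataset configs define these fields themselves. Each yielded
# batch is a list of (global_sample_idx, image_num, ps_h) tuples, intended for a
# dataset whose __getitem__ understands such composite indices.
def _example_mixed_batch_sampler():
    from types import SimpleNamespace
    class _MockDataset:
        def __init__(self, n):
            self.n = n
            self.cfg = SimpleNamespace(
                input_image_shape=[252, 252],
                view_sampler=SimpleNamespace(num_context_views=8, max_img_per_gpu=24),
            )
        def __len__(self):
            return self.n
    batch_sampler = MixedBatchSampler([_MockDataset(8), _MockDataset(12)],
                                      batch_size=4, num_context_views=8)
    for batch in batch_sampler:
        print(batch)  # e.g. [(9, 5, 18), (13, 5, 18), (11, 5, 18), (16, 5, 18)]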