"""
This file (`data_utils.py`) provides utility functions and classes for data handling in deep learning models.
It includes tools for moving tensors to specific devices, load-balancing utilities for distributed training,
and custom samplers for PyTorch DataLoaders that support resumable training and balanced data distribution.
Key components:
- Recursive device transfer functionality
- Load balancing utilities for distributing data across processes
- Cyclical iteration through data loaders
- Custom resumable samplers for distributed training
"""
import math
import os
from typing import Any, Dict, Iterator, List

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset, DistributedSampler, Sampler
def recursive_to_device(
data: Any,
device: torch.device,
non_blocking: bool = False,
) -> Any:
"""
Recursively move all tensors in a data structure to a device.
This function traverses nested data structures (lists, tuples, dictionaries)
and moves any PyTorch tensor to the specified device.
Args:
data: The data structure containing tensors to be moved
device: The target device (CPU, GPU) to move tensors to
non_blocking: If True, allows asynchronous copy to device if possible
Returns:
The same data structure with all tensors moved to the specified device
"""
if hasattr(data, "to"):
# print("Moving data to device")
# print(data)
return data.to(device, non_blocking=non_blocking)
elif isinstance(data, (list, tuple)):
# print("list or tuple detected")
return type(data)(recursive_to_device(d, device, non_blocking) for d in data)
elif isinstance(data, dict):
# print("dict detected")
return {k: recursive_to_device(v, device, non_blocking) for k, v in data.items()}
else:
# print(f"{type(data)} detected")
return data
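# Usage sketch for recursive_to_device (illustrative only; the batch layout is
# made up and a CUDA device is assumed to be available):
#
#   batch = {"image": torch.randn(4, 3, 32, 32), "meta": ("id-0", torch.arange(4))}
#   batch = recursive_to_device(batch, torch.device("cuda"), non_blocking=True)
#   # Tensors nested in the dict/tuple are now on the GPU; the string is untouched.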
def load_balanced_group_indices(
load: List[int],
num_groups: int,
equal_size: bool = False,
) -> List[List[int]]:
"""
Split indices into groups with balanced load.
This function distributes indices across groups to achieve balanced workload.
It uses a greedy algorithm that assigns each index to the group with the
minimum current load.
Args:
load: List of load values for each index
num_groups: Number of groups to split indices into
        equal_size: If True, each group will have the same number of elements
            (assumes ``len(load)`` is divisible by ``num_groups``; any leftover
            indices all land in group 0)
Returns:
List of lists, where each inner list contains indices assigned to a group
"""
if equal_size:
group_size = len(load) // num_groups
indices = np.argsort(load)[::-1] # Sort indices by load in descending order
groups = [[] for _ in range(num_groups)]
group_load = np.zeros(num_groups)
for idx in indices:
min_group_idx = np.argmin(group_load)
groups[min_group_idx].append(idx)
if equal_size and len(groups[min_group_idx]) == group_size:
group_load[min_group_idx] = float('inf') # Mark group as full
else:
group_load[min_group_idx] += load[idx]
return groups
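# Worked example for load_balanced_group_indices (load values chosen for
# illustration):
#
#   load_balanced_group_indices([5, 4, 3, 2, 1], num_groups=2)
#   # Greedy assignment by descending load puts indices 0, 3, 4 in group 0
#   # (total load 8) and indices 1, 2 in group 1 (total load 7).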
def cycle(data_loader: DataLoader) -> Iterator:
"""
Creates an infinite iterator over a data loader.
This function wraps a data loader to cycle through it repeatedly,
handling epoch tracking for various sampler types.
Args:
data_loader: The DataLoader to cycle through
Returns:
An iterator that indefinitely yields batches from the data loader
"""
while True:
for data in data_loader:
if isinstance(data_loader.sampler, ResumableSampler):
data_loader.sampler.idx += data_loader.batch_size # Update position for resumability
yield data
if isinstance(data_loader.sampler, DistributedSampler):
data_loader.sampler.epoch += 1 # Update epoch for DistributedSampler
if isinstance(data_loader.sampler, ResumableSampler):
data_loader.sampler.epoch += 1 # Update epoch for ResumableSampler
data_loader.sampler.idx = 0 # Reset position index
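# Usage sketch for cycle (assumes `dataset` is an existing Dataset instance):
#
#   loader = DataLoader(dataset, batch_size=8, sampler=ResumableSampler(dataset))
#   data_iter = cycle(loader)
#   for step in range(10_000):
#       batch = next(data_iter)  # epoch and resume index are tracked automatically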
class ResumableSampler(Sampler):
"""
Distributed sampler that is resumable.
This sampler extends PyTorch's Sampler to support resuming training from
a specific point. It tracks the current position (idx) and epoch to
enable checkpointing and resuming.
Args:
dataset: Dataset used for sampling.
shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
indices.
seed (int, optional): random seed used to shuffle the sampler if
:attr:`shuffle=True`. This number should be identical across all
processes in the distributed group. Default: ``0``.
drop_last (bool, optional): if ``True``, then the sampler will drop the
tail of the data to make it evenly divisible across the number of
replicas. If ``False``, the sampler will add extra indices to make
the data evenly divisible across the replicas. Default: ``False``.
"""
def __init__(
self,
dataset: Dataset,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
self.dataset = dataset
self.epoch = 0 # Current epoch counter
self.idx = 0 # Current index position for resuming
self.drop_last = drop_last
self.world_size = dist.get_world_size() if dist.is_initialized() else 1 # Get total number of processes
self.rank = dist.get_rank() if dist.is_initialized() else 0 # Get current process rank
# Calculate number of samples per process
if self.drop_last and len(self.dataset) % self.world_size != 0:
# Split to nearest available length that is evenly divisible
# This ensures each rank receives the same amount of data
self.num_samples = math.ceil(
(len(self.dataset) - self.world_size) / self.world_size
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.world_size)
self.total_size = self.num_samples * self.world_size # Total size after padding
self.shuffle = shuffle
self.seed = seed
def __iter__(self) -> Iterator:
if self.shuffle:
# Deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
if not self.drop_last:
# Add extra samples to make it evenly divisible across processes
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size] # Reuse some samples from the beginning
else:
indices += (indices * math.ceil(padding_size / len(indices)))[
:padding_size
] # Repeat samples if padding_size > len(indices)
else:
# Remove tail of data to make it evenly divisible
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# Subsample according to rank for distributed training
indices = indices[self.rank : self.total_size : self.world_size]
# Resume from previous state by skipping already processed indices
indices = indices[self.idx:]
return iter(indices)
def __len__(self) -> int:
return self.num_samples
def state_dict(self) -> Dict[str, int]:
"""
Returns the state of the sampler as a dictionary.
This enables saving the sampler state for checkpointing.
Returns:
Dictionary containing epoch and current index
"""
return {
'epoch': self.epoch,
'idx': self.idx,
}
    def load_state_dict(self, state_dict: Dict[str, int]) -> None:
"""
Loads the sampler state from a dictionary.
This enables restoring the sampler state from a checkpoint.
Args:
state_dict: Dictionary containing sampler state
"""
self.epoch = state_dict['epoch']
self.idx = state_dict['idx']
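# Checkpointing sketch for ResumableSampler (the checkpoint path is arbitrary):
#
#   sampler = ResumableSampler(dataset, shuffle=True, seed=0)
#   torch.save({"sampler": sampler.state_dict()}, "ckpt.pt")    # save mid-epoch
#   sampler.load_state_dict(torch.load("ckpt.pt")["sampler"])   # resume: skips
#   # the first `idx` samples of the current epoch on the next iteration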
class BalancedResumableSampler(ResumableSampler):
"""
Distributed sampler that is resumable and balances the load among the processes.
This sampler extends ResumableSampler to distribute data across processes
in a load-balanced manner, ensuring that each process receives a similar
computational workload despite potentially varying sample processing times.
Args:
dataset: Dataset used for sampling. Must have 'loads' attribute.
shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
indices.
seed (int, optional): random seed used to shuffle the sampler if
:attr:`shuffle=True`. This number should be identical across all
processes in the distributed group. Default: ``0``.
drop_last (bool, optional): if ``True``, then the sampler will drop the
tail of the data to make it evenly divisible across the number of
replicas. If ``False``, the sampler will add extra indices to make
the data evenly divisible across the replicas. Default: ``False``.
        batch_size (int, optional): Size of mini-batches used for balancing;
            should match the batch size of the DataLoader that uses this
            sampler. Default: ``1``.
"""
def __init__(
self,
dataset: Dataset,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
batch_size: int = 1,
) -> None:
assert hasattr(dataset, 'loads'), 'Dataset must have "loads" attribute to use BalancedResumableSampler'
super().__init__(dataset, shuffle, seed, drop_last)
self.batch_size = batch_size
self.loads = dataset.loads # Load values for each sample in the dataset
    def __iter__(self) -> Iterator:
        if self.shuffle:
            # Deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        if not self.drop_last:
            # Add extra samples to make the data evenly divisible across processes
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # Remove tail of data to make it evenly divisible
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # Balance load among processes: each contiguous block of
        # batch_size * world_size indices forms one global batch, which is
        # split into per-rank groups of equal size but balanced total load
        num_batches = len(indices) // (self.batch_size * self.world_size)
        balanced_indices = []
        loads = list(self.loads)
        if len(loads) < len(indices):
            # Repeat the loads to cover padded or duplicated indices; use a
            # local copy so the dataset's load list is not mutated across epochs
            loads = loads * (len(indices) // len(loads)) + loads[: len(indices) % len(loads)]
        for i in range(num_batches):
            start_idx = i * self.batch_size * self.world_size
            end_idx = (i + 1) * self.batch_size * self.world_size
            batch_indices = indices[start_idx:end_idx]
            batch_loads = [loads[idx] for idx in batch_indices]
            groups = load_balanced_group_indices(batch_loads, self.world_size, equal_size=True)
            balanced_indices.extend([batch_indices[j] for j in groups[self.rank]])

        # Resume from previous state by skipping already processed indices
        indices = balanced_indices[self.idx:]
        return iter(indices)
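# Usage sketch for BalancedResumableSampler (toy dataset; `loads` is a made-up
# per-sample cost such as a point count; with no process group initialized,
# world_size is 1 and rank is 0):
#
#   class ToyDataset(Dataset):
#       loads = [10, 1, 7, 3, 8, 2, 5, 4]   # one cost per sample
#       def __len__(self): return len(self.loads)
#       def __getitem__(self, idx): return idx
#
#   dataset = ToyDataset()
#   sampler = BalancedResumableSampler(dataset, batch_size=2)
#   loader = DataLoader(dataset, batch_size=2, sampler=sampler)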
class DuplicatedDataset(torch.utils.data.Dataset):
"""Dataset wrapper that duplicates a dataset multiple times."""
def __init__(self, dataset, repeat=1000):
"""
Initialize the duplicated dataset.
Args:
dataset: Original dataset to duplicate
repeat: Number of times to repeat the dataset
"""
self.dataset = dataset
self.repeat = repeat
self.original_length = len(dataset)
def __getitem__(self, idx):
"""Get an item from the original dataset, repeating as needed."""
return self.dataset[idx % self.original_length]
def __len__(self):
"""Return the length of the duplicated dataset."""
return self.original_length * self.repeat
    def __getattr__(self, name):
        """Forward all other attribute accesses to the original dataset."""
        # Guard own attributes to avoid infinite recursion if __getattr__ runs
        # before __init__ has set them (e.g. during unpickling)
        if name in ('dataset', 'repeat', 'original_length'):
            return object.__getattribute__(self, name)
        return getattr(self.dataset, name)
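# Usage sketch for DuplicatedDataset (handy when one pass over a small dataset
# is too short to be a meaningful epoch, e.g. overfitting experiments):
#
#   long_dataset = DuplicatedDataset(small_dataset, repeat=1000)
#   len(long_dataset)       # len(small_dataset) * 1000
#   long_dataset.loads      # attribute access falls through to small_dataset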
def save_coords_as_ply(coords: torch.Tensor, save_dir: str) -> str:
    """
    Save the coordinates to a PLY file using normalization similar to voxelize.py.
    Args:
        coords: Tensor of coordinates, shaped (N, 3) for XYZ rows or (N, 4)
            with a leading batch index per row.
        save_dir (str): The directory path to save the PLY file.
    Returns:
        Path of the written PLY file.
    """
    os.makedirs(save_dir, exist_ok=True)  # Ensure the directory exists
    # Get coordinates and convert to numpy
    coords_np = coords.cpu().numpy()
if coords_np.shape[1] == 4:
# Extract XYZ coordinates (skip batch index at position 0)
vertices = coords_np[:, 1:4]
else:
vertices = coords_np
# Normalize coordinates to [-0.5, 0.5] like in voxelize.py
# Assuming the coordinates are in a 64³ grid
GRID_SIZE = 64
vertices = (vertices + 0.5) / GRID_SIZE - 0.5
# print(f"Normalized vertex range: min={np.min(vertices, axis=0)}, max={np.max(vertices, axis=0)}")
# Create PLY file (simplified format like in voxelize.py)
filename = os.path.join(save_dir, 'coords.ply')
try:
with open(filename, 'w') as f:
# Write header (no color, just XYZ coordinates)
f.write("ply\n")
f.write("format ascii 1.0\n")
f.write(f"element vertex {vertices.shape[0]}\n")
f.write("property float x\n")
f.write("property float y\n")
f.write("property float z\n")
f.write("end_header\n")
# Write vertices (no color)
for i in range(vertices.shape[0]):
f.write(f"{vertices[i, 0]} {vertices[i, 1]} {vertices[i, 2]}\n")
# print(f"PLY file saved to {filename} with {vertices.shape[0]} points")
# Verify file creation
# file_size = os.path.getsize(filename)
# print(f"File size: {file_size} bytes")
except Exception as e:
print(f"Error creating PLY file: {e}")
return filename
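# Usage sketch for save_coords_as_ply (tensor contents and path are illustrative;
# coordinates are assumed to live on a 64^3 voxel grid, matching GRID_SIZE above):
#
#   coords = torch.randint(0, 64, (1024, 4))   # rows of (batch_idx, x, y, z)
#   ply_path = save_coords_as_ply(coords, "./debug_ply")
#   # Writes ./debug_ply/coords.ply with XYZ normalized into [-0.5, 0.5]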