# test_ebc/utils/data_utils.py
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms.v2 import Compose
import os, sys
from argparse import Namespace
from typing import Union, Tuple

# Make the repo root importable so the local `datasets` package can be found.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)

import datasets  # local package; must be imported after the sys.path tweak above
def get_dataloader(args: Namespace, split: str = "train", ddp: bool = False) -> Union[Tuple[DataLoader, Union[DistributedSampler, None]], DataLoader]:
    if split == "train":  # train split: strong augmentation
        transforms = Compose([
            datasets.RandomResizedCrop((args.input_size, args.input_size), scale=(args.min_scale, args.max_scale)),
            datasets.RandomHorizontalFlip(),
            datasets.RandomApply([
                datasets.ColorJitter(brightness=args.brightness, contrast=args.contrast, saturation=args.saturation, hue=args.hue),
                datasets.GaussianBlur(kernel_size=args.kernel_size, sigma=(0.1, 5.0)),
                datasets.PepperSaltNoise(saltiness=args.saltiness, spiciness=args.spiciness),
            ], p=(args.jitter_prob, args.blur_prob, args.noise_prob)),
        ])
    elif args.sliding_window:  # sliding-window evaluation: bring inputs to a size compatible with the window stride
        if args.resize_to_multiple:
            transforms = datasets.Resize2Multiple(args.window_size, stride=args.stride)
        elif args.zero_pad_to_multiple:
            transforms = datasets.ZeroPad2Multiple(args.window_size, stride=args.stride)
        else:
            transforms = None
    else:
        transforms = None
    dataset = datasets.Crowd(
        dataset=args.dataset,
        split=split,
        transforms=transforms,
        sigma=None,
        return_filename=False,
        num_crops=args.num_crops if split == "train" else 1,
    )
    if ddp and split == "train":  # data_loader for DDP training
        sampler = DistributedSampler(dataset)
        data_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=sampler,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
        )
        # NOTE: the caller should invoke sampler.set_epoch(epoch) at the start of
        # each epoch so that shuffling differs across epochs.
        return data_loader, sampler
    elif split == "train":  # data_loader for single-process training
        data_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
        )
        return data_loader, None  # no sampler outside DDP
    else:  # data_loader for evaluation
        data_loader = DataLoader(
            dataset,
            batch_size=1,  # use batch size 1 for evaluation
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
        )
        return data_loader
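

# Minimal usage sketch (not part of the original module): builds a Namespace with
# the attributes get_dataloader reads above. The attribute names come from the
# function itself; the values, the dataset name, and the split label are
# illustrative assumptions only — check datasets.Crowd for the real choices.
if __name__ == "__main__":
    args = Namespace(
        dataset="sha",            # assumed dataset identifier
        input_size=448,
        min_scale=0.5, max_scale=2.0,
        brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1,
        kernel_size=5,
        saltiness=1e-3, spiciness=1e-3,
        jitter_prob=0.5, blur_prob=0.2, noise_prob=0.2,
        sliding_window=False, resize_to_multiple=False, zero_pad_to_multiple=False,
        window_size=224, stride=224,
        num_crops=1,
        batch_size=8,
        num_workers=4,
    )
    # Training: returns (loader, sampler); sampler is None outside DDP.
    train_loader, sampler = get_dataloader(args, split="train", ddp=False)
    # Evaluation: returns only the loader (batch size fixed to 1).
    eval_loader = get_dataloader(args, split="test")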