# utils/data_utils.py
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms.v2 import Compose
import os
import sys
from argparse import Namespace
from typing import Union, Tuple
# Make the repo root importable so the project-local `datasets` package resolves.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)

import datasets  # provides the custom transforms, Crowd dataset, and collate_fn used below


def get_dataloader(
    args: Namespace, split: str = "train", ddp: bool = False
) -> Union[Tuple[DataLoader, Union[DistributedSampler, None]], DataLoader]:
    if split == "train":  # training split: strong augmentation
        transforms = Compose([
            datasets.RandomResizedCrop((args.input_size, args.input_size), scale=(args.min_scale, args.max_scale)),
            datasets.RandomHorizontalFlip(),
            # Project-local RandomApply: judging by the tuple `p`, it applies each
            # transform with its own probability, unlike torchvision's RandomApply,
            # which applies the whole list with a single float probability.
            datasets.RandomApply([
                datasets.ColorJitter(brightness=args.brightness, contrast=args.contrast, saturation=args.saturation, hue=args.hue),
                datasets.GaussianBlur(kernel_size=args.kernel_size, sigma=(0.1, 5.0)),
                datasets.PepperSaltNoise(saltiness=args.saltiness, spiciness=args.spiciness),
            ], p=(args.jitter_prob, args.blur_prob, args.noise_prob)),
        ])
    elif args.sliding_window:  # evaluation with sliding-window inference
        # Bring each image to a size compatible with the window/stride grid,
        # either by resizing or by zero-padding.
        if args.resize_to_multiple:
            transforms = datasets.Resize2Multiple(args.window_size, stride=args.stride)
        elif args.zero_pad_to_multiple:
            transforms = datasets.ZeroPad2Multiple(args.window_size, stride=args.stride)
        else:
            transforms = None
    else:
        transforms = None
    dataset = datasets.Crowd(
        dataset=args.dataset,
        split=split,
        transforms=transforms,
        sigma=None,
        return_filename=False,
        num_crops=args.num_crops if split == "train" else 1,
    )
    if ddp and split == "train":  # training DataLoader under DDP
        sampler = DistributedSampler(dataset)
        data_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=sampler,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
        )
        return data_loader, sampler
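    # Note: the caller is responsible for invoking sampler.set_epoch(epoch) at
    # the start of each epoch; otherwise DistributedSampler reuses the same
    # shuffling order every epoch.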
    elif split == "train":  # training DataLoader (single process)
        data_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
        )
        return data_loader, None
    else:  # evaluation DataLoader
        data_loader = DataLoader(
            dataset,
            batch_size=1,  # always evaluate with batch size 1
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
        )
        return data_loader
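

# --- Minimal usage sketch (illustrative only) ---
# The Namespace below is a hypothetical example of the attributes this function
# reads; every value (and the "sha" dataset name and "val" split name) is an
# assumption for demonstration, not a default shipped with this repo.
if __name__ == "__main__":
    args = Namespace(
        dataset="sha",          # assumed dataset identifier
        input_size=448, min_scale=0.5, max_scale=2.0,
        brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1,
        kernel_size=5, saltiness=0.01, spiciness=0.01,
        jitter_prob=0.5, blur_prob=0.2, noise_prob=0.2,
        sliding_window=False, resize_to_multiple=False, zero_pad_to_multiple=False,
        window_size=224, stride=224,
        num_crops=1, batch_size=8, num_workers=4,
    )

    # Single-process training: returns (loader, None); evaluation returns just the loader.
    train_loader, sampler = get_dataloader(args, split="train", ddp=False)
    val_loader = get_dataloader(args, split="val", ddp=False)
    # Under DDP (e.g. launched with torchrun), pass ddp=True instead and call
    # sampler.set_epoch(epoch) at the start of every epoch.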