from argparse import Namespace
import os
import sys
from typing import Optional, Tuple, Union

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms.v2 import Compose

# Make the project root importable so the local `datasets` package resolves.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)

import datasets  # noqa: E402  (must come after the sys.path modification above)


def get_dataloader(
    args: Namespace,
    split: str = "train",
    ddp: bool = False,
) -> Union[Tuple[DataLoader, Optional[DistributedSampler]], DataLoader]:
    """Build a DataLoader for the given split.

    Returns ``(data_loader, sampler)`` for training (``sampler`` is ``None``
    outside DDP) and a bare ``data_loader`` for evaluation.
    """
    if split == "train":
        # Training: strong augmentation.
        transforms = Compose([
            datasets.RandomResizedCrop(
                (args.input_size, args.input_size),
                scale=(args.min_scale, args.max_scale),
            ),
            datasets.RandomHorizontalFlip(),
            # Custom RandomApply: `p` pairs one probability with each transform.
            datasets.RandomApply([
                datasets.ColorJitter(
                    brightness=args.brightness,
                    contrast=args.contrast,
                    saturation=args.saturation,
                    hue=args.hue,
                ),
                datasets.GaussianBlur(kernel_size=args.kernel_size, sigma=(0.1, 5.0)),
                datasets.PepperSaltNoise(saltiness=args.saltiness, spiciness=args.spiciness),
            ], p=(args.jitter_prob, args.blur_prob, args.noise_prob)),
        ])
    elif args.sliding_window:
        # Sliding-window evaluation: make the input size compatible with the
        # window size and stride, either by resizing or by zero-padding.
        if args.resize_to_multiple:
            transforms = datasets.Resize2Multiple(args.window_size, stride=args.stride)
        elif args.zero_pad_to_multiple:
            transforms = datasets.ZeroPad2Multiple(args.window_size, stride=args.stride)
        else:
            transforms = None
    else:
        transforms = None

    dataset = datasets.Crowd(
        dataset=args.dataset,
        split=split,
        transforms=transforms,
        sigma=None,
        return_filename=False,
        num_crops=args.num_crops if split == "train" else 1,
    )

    if ddp and split == "train":
        # Training under DDP: each process draws a distinct shard of the data.
        sampler = DistributedSampler(dataset)
        data_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=sampler,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
        )
        return data_loader, sampler
    elif split == "train":
        # Single-process training.
        data_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
        )
        return data_loader, None
    else:
        # Evaluation: batch size 1, no shuffling.
        data_loader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
        )
        return data_loader
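

# --- Usage sketch (not part of the original module) ---
# A minimal example of wiring get_dataloader into a script. The argument names
# match those read above; the default values, the split name "val", and this
# __main__ entry point are illustrative assumptions, not part of the source.
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    # Dataset / loader options.
    parser.add_argument("--dataset", type=str, default="sha")  # assumed default
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--num_crops", type=int, default=1)
    # Training augmentation options.
    parser.add_argument("--input_size", type=int, default=224)
    parser.add_argument("--min_scale", type=float, default=0.5)
    parser.add_argument("--max_scale", type=float, default=2.0)
    parser.add_argument("--brightness", type=float, default=0.4)
    parser.add_argument("--contrast", type=float, default=0.4)
    parser.add_argument("--saturation", type=float, default=0.4)
    parser.add_argument("--hue", type=float, default=0.1)
    parser.add_argument("--kernel_size", type=int, default=5)
    parser.add_argument("--saltiness", type=float, default=0.01)
    parser.add_argument("--spiciness", type=float, default=0.01)
    parser.add_argument("--jitter_prob", type=float, default=0.8)
    parser.add_argument("--blur_prob", type=float, default=0.5)
    parser.add_argument("--noise_prob", type=float, default=0.5)
    # Sliding-window evaluation options.
    parser.add_argument("--sliding_window", action="store_true")
    parser.add_argument("--resize_to_multiple", action="store_true")
    parser.add_argument("--zero_pad_to_multiple", action="store_true")
    parser.add_argument("--window_size", type=int, default=224)
    parser.add_argument("--stride", type=int, default=224)
    args = parser.parse_args()

    # Training returns (loader, sampler); sampler is None outside DDP.
    train_loader, train_sampler = get_dataloader(args, split="train", ddp=False)
    # Evaluation returns a bare loader with batch size 1.
    val_loader = get_dataloader(args, split="val", ddp=False)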