Spaces:
Runtime error
Runtime error
from torch.utils.data import DataLoader | |
from torch.utils.data.distributed import DistributedSampler | |
from torchvision.transforms.v2 import Compose | |
import os, sys | |
from argparse import ArgumentParser | |
from typing import Union, Tuple | |
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | |
sys.path.append(parent_dir) | |
# import datasets | |
def get_dataloader(args: ArgumentParser, split: str = "train", ddp: bool = False) -> Union[Tuple[DataLoader, Union[DistributedSampler, None]], DataLoader]:
    """Build the ``DataLoader`` for one split of the crowd dataset.

    Args:
        args: Parsed command-line options. Despite the annotation, this must
            be an object exposing the attributes read below (``dataset``,
            ``num_workers``, ``num_crops``; for training also ``batch_size``,
            ``input_size``, ``min_scale``, ``max_scale``, the colour-jitter /
            blur / noise settings and their probabilities; for evaluation
            ``sliding_window`` and, when it is set, ``resize_to_multiple``,
            ``zero_pad_to_multiple``, ``window_size`` and ``stride``).
        split: Dataset split to load. ``"train"`` enables augmentation,
            shuffling and ``args.batch_size``; any other value is treated as
            an evaluation split.
        ddp: When true and ``split == "train"``, a ``DistributedSampler`` is
            used instead of plain shuffling. Ignored for evaluation splits.

    Returns:
        For ``split == "train"``: a ``(data_loader, sampler)`` tuple, where
        ``sampler`` is the ``DistributedSampler`` in DDP mode and ``None``
        otherwise. For any other split: the ``DataLoader`` alone.
    """
    # The module-level ``import datasets`` is commented out, so without this
    # every ``datasets.*`` reference below raises NameError. Importing here
    # keeps the module importable even where ``datasets`` is unavailable.
    import datasets

    is_train = split == "train"

    # Training uses strong augmentation; evaluation only optionally reshapes
    # the input so that it fits the sliding-window grid.
    transforms = None
    if is_train:
        transforms = Compose([
            datasets.RandomResizedCrop(
                (args.input_size, args.input_size),
                scale=(args.min_scale, args.max_scale),
            ),
            datasets.RandomHorizontalFlip(),
            datasets.RandomApply(
                [
                    datasets.ColorJitter(
                        brightness=args.brightness,
                        contrast=args.contrast,
                        saturation=args.saturation,
                        hue=args.hue,
                    ),
                    datasets.GaussianBlur(kernel_size=args.kernel_size, sigma=(0.1, 5.0)),
                    datasets.PepperSaltNoise(saltiness=args.saltiness, spiciness=args.spiciness),
                ],
                # One probability per transform in the list above.
                p=(args.jitter_prob, args.blur_prob, args.noise_prob),
            ),
        ])
    elif args.sliding_window:
        # Resizing takes precedence over zero-padding when both are set.
        if args.resize_to_multiple:
            transforms = datasets.Resize2Multiple(args.window_size, stride=args.stride)
        elif args.zero_pad_to_multiple:
            transforms = datasets.ZeroPad2Multiple(args.window_size, stride=args.stride)

    dataset = datasets.Crowd(
        dataset=args.dataset,
        split=split,
        transforms=transforms,
        sigma=None,
        return_filename=False,
        num_crops=args.num_crops if is_train else 1,
    )

    # Options shared by every loader built by this function.
    loader_kwargs = dict(
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=datasets.collate_fn,
    )

    if not is_train:
        # Evaluation: one image per batch, in dataset order.
        return DataLoader(dataset, batch_size=1, shuffle=False, **loader_kwargs)

    # In DDP mode the sampler handles shuffling, so ``shuffle`` is only
    # enabled when no sampler is supplied. The sampler is returned to the
    # caller (presumably so it can call ``set_epoch`` each epoch).
    sampler = DistributedSampler(dataset) if ddp else None
    data_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=sampler is None,
        **loader_kwargs,
    )
    return data_loader, sampler