Spaces:
Runtime error
Runtime error
File size: 2,855 Bytes
bb3e610 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import os
import sys
from argparse import ArgumentParser, Namespace
from typing import Tuple, Union

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms.v2 import Compose

# Make the project root importable so the local `datasets` package resolves.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)

import datasets  # local project package (not HF `datasets`); must follow the sys.path tweak
def get_dataloader(args: Namespace, split: str = "train", ddp: bool = False) -> Union[Tuple[DataLoader, Union[DistributedSampler, None]], DataLoader]:
    """Build a DataLoader for the Crowd dataset.

    Args:
        args: Parsed command-line arguments. Reads augmentation settings
            (input_size, min/max_scale, brightness, contrast, saturation, hue,
            kernel_size, saltiness, spiciness, jitter/blur/noise_prob),
            sliding-window settings (sliding_window, resize_to_multiple,
            zero_pad_to_multiple, window_size, stride) and loader settings
            (dataset, num_crops, batch_size, num_workers).
        split: Dataset split name. "train" enables strong augmentation and a
            shuffled (or DDP-sampled) loader; any other value builds an
            evaluation loader.
        ddp: If True and split == "train", wrap the dataset in a
            DistributedSampler for multi-process training.

    Returns:
        For split == "train": a ``(data_loader, sampler)`` tuple, where
        ``sampler`` is a DistributedSampler when ``ddp`` is True and None
        otherwise. For any other split: the bare ``data_loader``.
    """
    if split == "train":
        # Training: strong augmentation — random crop/flip plus an independent
        # per-transform probability for jitter, blur, and pepper-salt noise.
        transforms = Compose([
            datasets.RandomResizedCrop((args.input_size, args.input_size), scale=(args.min_scale, args.max_scale)),
            datasets.RandomHorizontalFlip(),
            datasets.RandomApply([
                datasets.ColorJitter(brightness=args.brightness, contrast=args.contrast, saturation=args.saturation, hue=args.hue),
                datasets.GaussianBlur(kernel_size=args.kernel_size, sigma=(0.1, 5.0)),
                datasets.PepperSaltNoise(saltiness=args.saltiness, spiciness=args.spiciness),
            ], p=(args.jitter_prob, args.blur_prob, args.noise_prob)),
        ])
    elif args.sliding_window:
        # Evaluation with sliding-window inference: make the image dimensions
        # compatible with the window/stride, either by resizing or zero-padding.
        if args.resize_to_multiple:
            transforms = datasets.Resize2Multiple(args.window_size, stride=args.stride)
        elif args.zero_pad_to_multiple:
            transforms = datasets.ZeroPad2Multiple(args.window_size, stride=args.stride)
        else:
            transforms = None
    else:
        transforms = None

    dataset = datasets.Crowd(
        dataset=args.dataset,
        split=split,
        transforms=transforms,
        sigma=None,
        return_filename=False,
        num_crops=args.num_crops if split == "train" else 1,
    )

    if ddp and split == "train":  # DDP training: each process sees a distinct shard
        sampler = DistributedSampler(dataset)
        data_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=sampler,  # sampler and shuffle are mutually exclusive
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
        )
        return data_loader, sampler
    elif split == "train":  # single-process training
        data_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
        )
        return data_loader, None  # None keeps the return shape uniform for train callers
    else:  # evaluation: deterministic order, one image at a time
        data_loader = DataLoader(
            dataset,
            batch_size=1,  # batch size 1 — eval images may have differing sizes
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=datasets.collate_fn,
        )
        return data_loader
|