import torch
import numpy as np
from torch.utils.data.sampler import WeightedRandomSampler

from .datasets import RealFakeDataset


def get_bal_sampler(dataset):
    """Build a sampler that weights every sample by the inverse of its class count.

    Each sample's weight is ``1 / count(its label)``, so a
    ``WeightedRandomSampler`` draws each label with roughly equal expected
    frequency per epoch.

    Args:
        dataset: Either a concatenation that exposes ``.datasets`` (each child
            exposing ``.targets``), or a single dataset exposing ``.targets``.
            Targets must be non-negative integer class labels, as required by
            ``np.bincount``.

    Returns:
        A ``WeightedRandomSampler`` whose epoch length equals the total number
        of targets.

    Raises:
        ValueError: if the dataset yields no targets at all.
    """
    # Accept a plain dataset as well as a concatenation; before, a dataset
    # without ``.datasets`` raised AttributeError here.
    children = getattr(dataset, 'datasets', [dataset])
    targets = []
    for child in children:
        targets.extend(child.targets)
    if not targets:
        raise ValueError('get_bal_sampler: dataset has no targets to balance')

    class_counts = np.bincount(targets)
    # Labels missing from ``targets`` get count 0 and therefore weight inf,
    # but those entries are never selected by the indexing below.
    class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
    sample_weights = class_weights[targets]
    return WeightedRandomSampler(weights=sample_weights,
                                 num_samples=len(sample_weights))


def create_dataloader(opt, preprocess=None):
    """Create a DataLoader over a ``RealFakeDataset`` configured by ``opt``.

    Args:
        opt: Options object. This function reads ``serial_batches``,
            ``isTrain``, ``class_bal``, ``arch``, ``batch_size`` and
            ``num_threads``; ``RealFakeDataset`` may read more.
        preprocess: Transform assigned to ``dataset.transform`` when the
            substring ``'2b'`` occurs in ``opt.arch``; otherwise unused.

    Returns:
        A ``torch.utils.data.DataLoader`` over the constructed dataset.
    """
    # Shuffle only while training without class balancing. Whenever a sampler
    # is supplied, shuffle must be False because DataLoader rejects
    # shuffle=True together with a sampler.
    shuffle = not opt.serial_batches if (opt.isTrain and not opt.class_bal) else False
    dataset = RealFakeDataset(opt)
    print(len(dataset))
    if '2b' in opt.arch:
        # NOTE(review): replaces the dataset's transform with the caller's
        # preprocess for '2b' architectures; the reason is not visible here.
        dataset.transform = preprocess

    # NOTE(review): the balanced sampler is also built when not training
    # (opt.isTrain False); WeightedRandomSampler samples with replacement by
    # default, so evaluation would then repeat/skip items. Confirm intended.
    sampler = get_bal_sampler(dataset) if opt.class_bal else None

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=shuffle,
                                              sampler=sampler,
                                              num_workers=int(opt.num_threads))
    return data_loader