|
import torch |
|
import numpy as np |
|
from torch.utils.data.sampler import WeightedRandomSampler |
|
|
|
from .datasets import RealFakeDataset |
|
|
|
|
|
|
|
def get_bal_sampler(dataset):
    """Build a sampler that draws every class with equal expected frequency.

    Each sample is weighted by the inverse of its class's count, and a
    ``WeightedRandomSampler`` (sampling with replacement, its default) draws
    as many samples per epoch as there are targets.

    Args:
        dataset: Either a concatenation-style dataset exposing ``.datasets``
            (each sub-dataset providing a sequence of non-negative integer
            ``.targets``), or a single dataset exposing ``.targets`` directly.

    Returns:
        A ``WeightedRandomSampler`` over all samples of ``dataset``.

    Raises:
        AttributeError: if ``dataset`` has neither ``.datasets`` nor ``.targets``.
        ValueError: if ``dataset`` has no targets to sample from.
    """
    if hasattr(dataset, 'datasets'):
        # Concatenated dataset: gather labels in the same order as the
        # concatenation indexes its samples.
        targets = [t for d in dataset.datasets for t in d.targets]
    else:
        # Fallback for a plain dataset that exposes its labels directly.
        targets = list(dataset.targets)

    if not targets:
        raise ValueError('get_bal_sampler: dataset has no targets to balance')

    labels = np.asarray(targets, dtype=np.int64)
    class_counts = np.bincount(labels)

    # Classes that never occur have a zero count and get an infinite weight
    # here, but such entries are never selected by the indexing below.
    class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
    sample_weights = class_weights[torch.from_numpy(labels)]

    return WeightedRandomSampler(weights=sample_weights,
                                 num_samples=len(sample_weights))
|
|
|
|
|
def create_dataloader(opt, preprocess=None):
    """Create the DataLoader over ``RealFakeDataset(opt)``.

    Args:
        opt: Options object; this function reads ``isTrain``, ``class_bal``,
            ``serial_batches``, ``arch``, ``batch_size`` and ``num_threads``
            (and ``RealFakeDataset`` may read more).
        preprocess: Transform assigned to the dataset's ``transform``
            attribute when ``opt.arch`` contains ``'2b'``; otherwise unused.

    Returns:
        A ``torch.utils.data.DataLoader``. It shuffles only when training
        without class balancing and without ``serial_batches``. With
        ``opt.class_bal`` set, it uses the class-balanced sampler from
        ``get_bal_sampler`` instead.
    """
    # A DataLoader may not combine an explicit sampler with shuffle=True, so
    # shuffling is disabled whenever class balancing is requested.
    use_shuffle = bool(opt.isTrain and not opt.class_bal and not opt.serial_batches)

    real_fake = RealFakeDataset(opt)
    print(len(real_fake))

    if '2b' in opt.arch:
        # Override the dataset's own transform with the caller-supplied one
        # (which may be None if the caller did not pass one).
        real_fake.transform = preprocess

    balanced_sampler = None
    if opt.class_bal:
        balanced_sampler = get_bal_sampler(real_fake)

    return torch.utils.data.DataLoader(
        real_fake,
        batch_size=opt.batch_size,
        shuffle=use_shuffle,
        sampler=balanced_sampler,
        num_workers=int(opt.num_threads),
    )
|
|