r""" Dataloader builder for few-shot semantic segmentation dataset """
from torchvision import transforms
from torch.utils.data import DataLoader

from fewshot_data.data.pascal import DatasetPASCAL
from fewshot_data.data.coco import DatasetCOCO
from fewshot_data.data.fss import DatasetFSS

class FSSDataset:

    @classmethod
    def initialize(cls, img_size, datapath, use_original_imgsize, imagenet_norm=False):
        # Registry of the supported few-shot segmentation benchmarks.
        cls.datasets = {
            'pascal': DatasetPASCAL,
            'coco': DatasetCOCO,
            'fss': DatasetFSS,
        }

        # Use ImageNet statistics when requested; otherwise scale inputs to [-1, 1].
        if imagenet_norm:
            cls.img_mean = [0.485, 0.456, 0.406]
            cls.img_std = [0.229, 0.224, 0.225]
        else:
            cls.img_mean = [0.5] * 3
            cls.img_std = [0.5] * 3
        print('use norm: {}, {}'.format(cls.img_mean, cls.img_std))

        cls.datapath = datapath
        cls.use_original_imgsize = use_original_imgsize

        # Shared preprocessing pipeline: resize to a square, convert to tensor, normalize.
        cls.transform = transforms.Compose([transforms.Resize(size=(img_size, img_size)),
                                            transforms.ToTensor(),
                                            transforms.Normalize(cls.img_mean, cls.img_std)])

    @classmethod
    def build_dataloader(cls, benchmark, bsz, nworker, fold, split, shot=1):
        # Shuffle only the training split; evaluation splits keep a fixed order.
        shuffle = split == 'trn'
        # Worker processes are used only for training; other splits load in the main process.
        nworker = nworker if split == 'trn' else 0
        dataset = cls.datasets[benchmark](cls.datapath, fold=fold, transform=cls.transform,
                                          split=split, shot=shot,
                                          use_original_imgsize=cls.use_original_imgsize)
        dataloader = DataLoader(dataset, batch_size=bsz, shuffle=shuffle, num_workers=nworker)

        return dataloader
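

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal sketch of how this builder might be driven. The dataset root,
# image size, batch size, and fold below are placeholder assumptions, not
# values prescribed by this module.
if __name__ == '__main__':
    FSSDataset.initialize(img_size=400, datapath='path/to/datasets',
                          use_original_imgsize=False, imagenet_norm=True)
    trn_loader = FSSDataset.build_dataloader(benchmark='pascal', bsz=8, nworker=4,
                                             fold=0, split='trn', shot=1)
    val_loader = FSSDataset.build_dataloader(benchmark='pascal', bsz=1, nworker=4,
                                             fold=0, split='val', shot=1)
    # Each batch is whatever the underlying Dataset* class yields (support and
    # query images with masks); inspect one batch to confirm its structure.
    for batch in trn_loader:
        print(type(batch))
        break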
|