kairunwen's picture
Update Code
57746f1
r""" Dataloader builder for few-shot semantic segmentation dataset """
from torchvision import transforms
from torch.utils.data import DataLoader
from fewshot_data.data.pascal import DatasetPASCAL
from fewshot_data.data.coco import DatasetCOCO
from fewshot_data.data.fss import DatasetFSS
class FSSDataset:
    """Factory for few-shot semantic segmentation datasets and their DataLoaders.

    All configuration is stored as class attributes by :meth:`initialize`,
    which must be called once before :meth:`build_dataloader`.
    """

    @classmethod
    def initialize(cls, img_size, datapath, use_original_imgsize, imagenet_norm=False):
        """Configure the dataset registry, data root and shared image transform.

        Args:
            img_size: Images are resized to ``(img_size, img_size)``.
            datapath: Root path forwarded to every dataset constructor.
            use_original_imgsize: Flag forwarded to every dataset constructor.
            imagenet_norm: If True, normalize with the ImageNet per-channel
                mean/std; otherwise use mean = std = 0.5 for each of the
                three channels.
        """
        # Registry mapping a benchmark name to its dataset class.
        cls.datasets = {
            'pascal': DatasetPASCAL,
            'coco': DatasetCOCO,
            'fss': DatasetFSS,
        }

        if imagenet_norm:
            cls.img_mean = [0.485, 0.456, 0.406]
            cls.img_std = [0.229, 0.224, 0.225]
        else:
            cls.img_mean = [0.5] * 3
            cls.img_std = [0.5] * 3
        # Report the chosen normalization statistics once, whichever branch ran.
        print('use norm: {}, {}'.format(cls.img_mean, cls.img_std))

        cls.datapath = datapath
        cls.use_original_imgsize = use_original_imgsize
        cls.transform = transforms.Compose([
            transforms.Resize(size=(img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(cls.img_mean, cls.img_std),
        ])

    @classmethod
    def build_dataloader(cls, benchmark, bsz, nworker, fold, split, shot=1):
        """Build a ``DataLoader`` for one benchmark/fold/split.

        Args:
            benchmark: Key into the registry set up by :meth:`initialize`
                (``'pascal'``, ``'coco'`` or ``'fss'``).
            bsz: Batch size.
            nworker: Number of loader worker processes; only honoured when
                ``split == 'trn'``.
            fold: Fold index forwarded to the dataset constructor.
            split: Split name forwarded to the dataset constructor; ``'trn'``
                enables shuffling and multi-worker loading.
            shot: Number of support examples, forwarded to the dataset.

        Returns:
            A ``torch.utils.data.DataLoader`` over the constructed dataset.

        Raises:
            KeyError: If ``benchmark`` is not a registered benchmark name.
        """
        try:
            dataset_cls = cls.datasets[benchmark]
        except KeyError:
            # Same exception type as a plain dict miss, but say what is valid.
            raise KeyError('unknown benchmark {!r}; expected one of {}'.format(
                benchmark, sorted(cls.datasets))) from None

        is_train = split == 'trn'
        # Only the training split is shuffled; other splits are loaded in the
        # main process (num_workers=0) regardless of the requested nworker.
        shuffle = is_train
        nworker = nworker if is_train else 0

        dataset = dataset_cls(cls.datapath, fold=fold, transform=cls.transform,
                              split=split, shot=shot,
                              use_original_imgsize=cls.use_original_imgsize)
        return DataLoader(dataset, batch_size=bsz, shuffle=shuffle, num_workers=nworker)