r""" PASCAL-5i few-shot semantic segmentation dataset """
import os
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
import PIL.Image as Image
import numpy as np
class DatasetPASCAL(Dataset):
    r""" PASCAL-5i: the 20 PASCAL VOC classes are split into 4 folds of 5;
    the classes of the chosen fold are held out for validation and the
    remaining 15 are used for training. """
    def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize):
        self.split = 'val' if split in ['val', 'test'] else 'trn'
        self.fold = fold
        self.nfolds = 4
        self.nclass = 20
        self.benchmark = 'pascal'
        self.shot = shot
        self.use_original_imgsize = use_original_imgsize

        self.img_path = os.path.join(datapath, 'VOC2012/JPEGImages/')
        self.ann_path = os.path.join(datapath, 'VOC2012/SegmentationClassAug/')
        self.transform = transform

        self.class_ids = self.build_class_ids()
        self.img_metadata = self.build_img_metadata()
        self.img_metadata_classwise = self.build_img_metadata_classwise()

    def __len__(self):
        # validation/testing always iterates over a fixed budget of 1000 episodes
        return len(self.img_metadata) if self.split == 'trn' else 1000

    def __getitem__(self, idx):
        idx %= len(self.img_metadata)  # for testing, as n_images < 1000
        query_name, support_names, class_sample = self.sample_episode(idx)
        query_img, query_cmask, support_imgs, support_cmasks, org_qry_imsize = self.load_frame(query_name, support_names)

        query_img = self.transform(query_img)
        if not self.use_original_imgsize:
            query_cmask = F.interpolate(query_cmask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()
        query_mask, query_ignore_idx = self.extract_ignore_idx(query_cmask.float(), class_sample)

        if self.shot:
            support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])

            support_masks = []
            support_ignore_idxs = []
            for scmask in support_cmasks:
                scmask = F.interpolate(scmask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
                support_mask, support_ignore_idx = self.extract_ignore_idx(scmask, class_sample)
                support_masks.append(support_mask)
                support_ignore_idxs.append(support_ignore_idx)
            support_masks = torch.stack(support_masks)
            support_ignore_idxs = torch.stack(support_ignore_idxs)
        else:
            support_masks = []
            support_ignore_idxs = []

        batch = {'query_img': query_img,
                 'query_mask': query_mask,
                 'query_name': query_name,
                 'query_ignore_idx': query_ignore_idx,
                 'org_query_imsize': org_qry_imsize,
                 'support_imgs': support_imgs,
                 'support_masks': support_masks,
                 'support_names': support_names,
                 'support_ignore_idxs': support_ignore_idxs,
                 'class_id': torch.tensor(class_sample)}

        return batch

    def extract_ignore_idx(self, mask, class_id):
        # pixels labeled 255 in VOC masks mark object boundaries to be ignored
        boundary = (mask / 255).floor()
        # binarize: 1 for the episode's class (stored as class_id + 1 in the raw mask), 0 elsewhere
        mask[mask != class_id + 1] = 0
        mask[mask == class_id + 1] = 1

        return mask, boundary

    def load_frame(self, query_name, support_names):
        query_img = self.read_img(query_name)
        query_mask = self.read_mask(query_name)
        support_imgs = [self.read_img(name) for name in support_names]
        support_masks = [self.read_mask(name) for name in support_names]

        org_qry_imsize = query_img.size

        return query_img, query_mask, support_imgs, support_masks, org_qry_imsize

    def read_mask(self, img_name):
        r"""Return segmentation mask as a torch tensor"""
        mask = torch.tensor(np.array(Image.open(os.path.join(self.ann_path, img_name) + '.png')))
        return mask

    def read_img(self, img_name):
        r"""Return RGB image as a PIL Image"""
        return Image.open(os.path.join(self.img_path, img_name) + '.jpg')

    def sample_episode(self, idx):
        query_name, class_sample = self.img_metadata[idx]

        support_names = []
        if self.shot:
            while True:  # keep sampling support set if query == support
                support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
                if query_name != support_name:
                    support_names.append(support_name)
                if len(support_names) == self.shot:
                    break

        return query_name, support_names, class_sample

    def build_class_ids(self):
        nclass_trn = self.nclass // self.nfolds
        class_ids_val = [self.fold * nclass_trn + i for i in range(nclass_trn)]
        class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val]

        if self.split == 'trn':
            return class_ids_trn
        else:
            return class_ids_val

    def build_img_metadata(self):

        def read_metadata(split, fold_id):
            fold_n_metadata = os.path.join('fewshot_data/data/splits/pascal/%s/fold%d.txt' % (split, fold_id))
            with open(fold_n_metadata, 'r') as f:
                fold_n_metadata = f.read().split('\n')[:-1]
            fold_n_metadata = [[data.split('__')[0], int(data.split('__')[1]) - 1] for data in fold_n_metadata]
            return fold_n_metadata

        img_metadata = []
        if self.split == 'trn':  # For training, read image-metadata of "the other" folds
            for fold_id in range(self.nfolds):
                if fold_id == self.fold:  # Skip validation fold
                    continue
                img_metadata += read_metadata(self.split, fold_id)
        elif self.split == 'val':  # For validation, read image-metadata of "current" fold
            img_metadata = read_metadata(self.split, self.fold)
        else:
            raise Exception('Undefined split: %s' % self.split)

        print('Total (%s) images: %d' % (self.split, len(img_metadata)))

        return img_metadata

    def build_img_metadata_classwise(self):
        img_metadata_classwise = {}
        for class_id in range(self.nclass):
            img_metadata_classwise[class_id] = []

        for img_name, img_class in self.img_metadata:
            img_metadata_classwise[img_class] += [img_name]

        return img_metadata_classwise
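

# Minimal usage sketch (illustrative, not part of the original file). Builds a
# fold-0 validation set and fetches one 1-shot episode. Assumes torchvision is
# installed and that `datapath` points at a root containing VOC2012/JPEGImages
# and VOC2012/SegmentationClassAug; 'YOUR_VOC_ROOT' below is a placeholder.
# Run from the repository root so the relative split-file paths resolve.
if __name__ == '__main__':
    from torchvision import transforms

    transform = transforms.Compose([transforms.Resize((400, 400)),
                                    transforms.ToTensor()])
    dataset = DatasetPASCAL(datapath='YOUR_VOC_ROOT', fold=0, transform=transform,
                            split='val', shot=1, use_original_imgsize=False)
    episode = dataset[0]
    print(episode['query_img'].shape)     # torch.Size([3, 400, 400])
    print(episode['support_imgs'].shape)  # torch.Size([1, 3, 400, 400])
    print(episode['class_id'])            # held-out class index for this episode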