# Source: commit 57746f1 ("Update Code") by kairunwen.
r""" PASCAL-5i few-shot semantic segmentation dataset """
import os
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
import PIL.Image as Image
import numpy as np
class DatasetPASCAL(Dataset):
    r"""PASCAL-5i few-shot segmentation dataset.

    Each item is one episode: a query image plus ``shot`` support images that
    share the query's target class. The ``nclass`` classes are divided into
    ``nfolds`` equal folds; the training split uses the classes of every fold
    except ``fold``, the val/test split uses only the classes of ``fold``.

    Split lists are read from
    ``fewshot_data/data/splits/pascal/<split>/fold<k>.txt`` relative to the
    current working directory; each non-empty line is
    ``<image name>__<1-based class id>``.

    Args:
        datapath: root directory containing ``VOC2012/JPEGImages`` and
            ``VOC2012/SegmentationClassAug``.
        fold: index of the held-out fold.
        transform: callable applied to each PIL image; its result must be a
            tensor (``.size()`` is read from it).
        split: ``'val'`` or ``'test'`` select the held-out fold; any other
            value selects training.
        shot: number of support images per episode (0 disables supports).
        use_original_imgsize: if True, the query mask is not resized to the
            transformed query image.
    """

    def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize):
        # 'test' shares the validation fold; everything else is training.
        self.split = 'val' if split in ['val', 'test'] else 'trn'
        self.fold = fold
        self.nfolds = 4
        self.nclass = 20
        self.benchmark = 'pascal'
        self.shot = shot
        self.use_original_imgsize = use_original_imgsize

        self.img_path = os.path.join(datapath, 'VOC2012/JPEGImages/')
        self.ann_path = os.path.join(datapath, 'VOC2012/SegmentationClassAug/')
        self.transform = transform

        self.class_ids = self.build_class_ids()
        self.img_metadata = self.build_img_metadata()
        self.img_metadata_classwise = self.build_img_metadata_classwise()

    def __len__(self):
        # Training visits every (image, class) entry once; val/test always
        # yields 1000 episodes, with indices wrapped in __getitem__.
        return len(self.img_metadata) if self.split == 'trn' else 1000

    def __getitem__(self, idx):
        r"""Return the episode for ``idx`` as a dict.

        Keys: ``query_img``; ``query_mask`` (1 where the pixel carries the
        episode class, else 0); ``query_name``; ``query_ignore_idx`` (1 where
        the raw mask value is 255); ``org_query_imsize`` (PIL ``size`` of the
        untransformed query, i.e. ``(width, height)``); ``support_imgs``;
        ``support_masks``; ``support_names``; ``support_ignore_idxs``; and
        ``class_id`` (0-based class index as a tensor). When ``shot`` is 0
        the support entries are empty lists.
        """
        idx %= len(self.img_metadata)  # for testing, as n_images < 1000
        query_name, support_names, class_sample = self.sample_episode(idx)
        query_img, query_cmask, support_imgs, support_cmasks, org_qry_imsize = self.load_frame(query_name, support_names)

        query_img = self.transform(query_img)
        if not self.use_original_imgsize:
            # Nearest-neighbour resizing keeps the integer label values intact.
            query_cmask = F.interpolate(query_cmask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze()
        query_mask, query_ignore_idx = self.extract_ignore_idx(query_cmask.float(), class_sample)

        if self.shot:
            support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])
            support_masks = []
            support_ignore_idxs = []
            for scmask in support_cmasks:
                # Support masks are always resized to the transformed supports.
                scmask = F.interpolate(scmask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
                support_mask, support_ignore_idx = self.extract_ignore_idx(scmask, class_sample)
                support_masks.append(support_mask)
                support_ignore_idxs.append(support_ignore_idx)
            support_masks = torch.stack(support_masks)
            support_ignore_idxs = torch.stack(support_ignore_idxs)
        else:
            support_masks = []
            support_ignore_idxs = []

        batch = {'query_img': query_img,
                 'query_mask': query_mask,
                 'query_name': query_name,
                 'query_ignore_idx': query_ignore_idx,
                 'org_query_imsize': org_qry_imsize,
                 'support_imgs': support_imgs,
                 'support_masks': support_masks,
                 'support_names': support_names,
                 'support_ignore_idxs': support_ignore_idxs,
                 'class_id': torch.tensor(class_sample)}
        return batch

    def extract_ignore_idx(self, mask, class_id):
        r"""Split a label mask into a binary foreground mask and an ignore mask.

        Masks store classes 1-based, so the episode class is ``class_id + 1``.

        Returns:
            ``(mask, boundary)``: ``mask`` is 1 where the label equals
            ``class_id + 1`` and 0 elsewhere; ``boundary`` is
            ``floor(mask / 255)``, i.e. 1 where the original value was 255
            (for labels in 0..255).

        Note: ``mask`` is modified in place and returned.
        """
        # Computed before the in-place relabelling below destroys the 255s.
        boundary = (mask / 255).floor()
        mask[mask != class_id + 1] = 0
        mask[mask == class_id + 1] = 1
        return mask, boundary

    def load_frame(self, query_name, support_names):
        r"""Read the query/support images and masks.

        Returns ``(query_img, query_mask, support_imgs, support_masks,
        org_qry_imsize)`` where ``org_qry_imsize`` is the query's PIL ``size``.
        """
        query_img = self.read_img(query_name)
        query_mask = self.read_mask(query_name)
        support_imgs = [self.read_img(name) for name in support_names]
        support_masks = [self.read_mask(name) for name in support_names]
        org_qry_imsize = query_img.size
        return query_img, query_mask, support_imgs, support_masks, org_qry_imsize

    def read_mask(self, img_name):
        r"""Return the mask ``<ann_path>/<img_name>.png`` as a torch tensor."""
        with Image.open(os.path.join(self.ann_path, img_name) + '.png') as mask:
            # np.array() decodes the pixels, so the file can be closed afterwards.
            return torch.tensor(np.array(mask))

    def read_img(self, img_name):
        r"""Return the image ``<img_path>/<img_name>.jpg`` as an RGB PIL Image."""
        with Image.open(os.path.join(self.img_path, img_name) + '.jpg') as img:
            # convert() loads the pixels into a new image, so closing the file
            # is safe; it also guarantees 3 channels for non-RGB JPEGs.
            return img.convert('RGB')

    def sample_episode(self, idx):
        r"""Pick the query entry ``idx`` and draw ``shot`` support names.

        Supports are drawn independently from all images of the query's class
        (rejecting the query itself), so with ``shot > 1`` the same support
        image may be drawn more than once.

        Raises:
            ValueError: if ``shot > 0`` and the class has no image other than
                the query (the rejection loop could never finish).
        """
        query_name, class_sample = self.img_metadata[idx]
        support_names = []
        if self.shot:
            candidates = self.img_metadata_classwise[class_sample]
            if all(name == query_name for name in candidates):
                raise ValueError('No support image available for query %s (class %d)' % (query_name, class_sample))
            while len(support_names) < self.shot:  # keep sampling support set if query == support
                support_name = np.random.choice(candidates, 1, replace=False)[0]
                if query_name != support_name:
                    support_names.append(support_name)
        return query_name, support_names, class_sample

    def build_class_ids(self):
        r"""Return the 0-based class ids used by the current split."""
        nclass_per_fold = self.nclass // self.nfolds
        class_ids_val = [self.fold * nclass_per_fold + i for i in range(nclass_per_fold)]
        class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val]

        if self.split == 'trn':
            return class_ids_trn
        else:
            return class_ids_val

    def build_img_metadata(self):
        r"""Return ``[[image name, 0-based class id], ...]`` for the current split."""

        def read_metadata(split, fold_id):
            fold_n_metadata = 'fewshot_data/data/splits/pascal/%s/fold%d.txt' % (split, fold_id)
            with open(fold_n_metadata, 'r') as f:
                # splitlines() keeps the last entry even without a trailing
                # newline; blank lines are skipped instead of crashing.
                lines = [line for line in f.read().splitlines() if line.strip()]
            parsed = []
            for line in lines:
                fields = line.split('__')
                # Split files use 1-based class ids; convert to 0-based.
                parsed.append([fields[0], int(fields[1]) - 1])
            return parsed

        img_metadata = []
        if self.split == 'trn':  # For training, read image-metadata of "the other" folds
            for fold_id in range(self.nfolds):
                if fold_id == self.fold:  # Skip validation fold
                    continue
                img_metadata += read_metadata(self.split, fold_id)
        elif self.split == 'val':  # For validation, read image-metadata of "current" fold
            img_metadata = read_metadata(self.split, self.fold)
        else:
            raise Exception('Undefined split %s: ' % self.split)

        print('Total (%s) images are : %d' % (self.split, len(img_metadata)))

        return img_metadata

    def build_img_metadata_classwise(self):
        r"""Group image names by class; every class id in ``range(nclass)`` gets a key."""
        img_metadata_classwise = {class_id: [] for class_id in range(self.nclass)}
        for img_name, img_class in self.img_metadata:
            img_metadata_classwise[img_class].append(img_name)
        return img_metadata_classwise