r""" COCO-20i few-shot semantic segmentation dataset """
import os
import pickle

import numpy as np
import PIL.Image as Image
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset


class DatasetCOCO(Dataset):
    def __init__(self, datapath, fold, transform, split, shot, use_original_imgsize):
        self.split = 'val' if split in ['val', 'test'] else 'trn'
        self.fold = fold
        self.nfolds = 4                  # the 80 COCO classes are split into 4 folds of 20
        self.nclass = 80
        self.benchmark = 'coco'
        self.shot = shot                 # number of support examples per episode
        self.split_coco = split if split == 'val2014' else 'train2014'
        self.base_path = os.path.join(datapath, 'COCO2014')
        self.transform = transform
        self.use_original_imgsize = use_original_imgsize

        self.class_ids = self.build_class_ids()
        self.img_metadata_classwise = self.build_img_metadata_classwise()
        self.img_metadata = self.build_img_metadata()
    def __len__(self):
        # validation/testing is evaluated on a fixed budget of 1000 randomly sampled episodes
        return len(self.img_metadata) if self.split == 'trn' else 1000
    def __getitem__(self, idx):
        # idx is ignored during both training and testing; instead, object classes are
        # sampled uniformly to form each episode (due to the large size of the COCO dataset)
        query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize = self.load_frame()

        query_img = self.transform(query_img)
        query_mask = query_mask.float()
        if not self.use_original_imgsize:
            query_mask = F.interpolate(query_mask.unsqueeze(0).unsqueeze(0), query_img.size()[-2:], mode='nearest').squeeze()
        if self.shot:
            support_imgs = torch.stack([self.transform(support_img) for support_img in support_imgs])
            # resize each support mask to match the transformed support image size
            for midx, smask in enumerate(support_masks):
                support_masks[midx] = F.interpolate(smask.unsqueeze(0).unsqueeze(0).float(), support_imgs.size()[-2:], mode='nearest').squeeze()
            support_masks = torch.stack(support_masks)
batch = {'query_img': query_img,
'query_mask': query_mask,
'query_name': query_name,
'org_query_imsize': org_qry_imsize,
'support_imgs': support_imgs,
'support_masks': support_masks,
'support_names': support_names,
'class_id': torch.tensor(class_sample)}
return batch
    def build_class_ids(self):
        nclass_trn = self.nclass // self.nfolds
        # fold f holds out every 4th class starting at f for validation,
        # e.g. fold 0 -> val classes [0, 4, 8, ..., 76]; the remaining 60 are used for training
        class_ids_val = [self.fold + self.nfolds * v for v in range(nclass_trn)]
        class_ids_trn = [x for x in range(self.nclass) if x not in class_ids_val]
        class_ids = class_ids_trn if self.split == 'trn' else class_ids_val
        return class_ids
    def build_img_metadata_classwise(self):
        # precomputed {class_id: [image names]} mapping for this split/fold
        with open('fewshot_data/data/splits/coco/%s/fold%d.pkl' % (self.split, self.fold), 'rb') as f:
            img_metadata_classwise = pickle.load(f)
        return img_metadata_classwise
    def build_img_metadata(self):
        # flat, deduplicated list of all image names across classes
        img_metadata = []
        for k in self.img_metadata_classwise:
            img_metadata += self.img_metadata_classwise[k]
        return sorted(set(img_metadata))
    def read_mask(self, name):
        # annotation masks are stored as .png files mirroring the .jpg image names
        mask_path = os.path.join(self.base_path, 'annotations', name)
        mask = torch.tensor(np.array(Image.open(mask_path[:mask_path.index('.jpg')] + '.png')))
        return mask
    def load_frame(self):
        class_sample = np.random.choice(self.class_ids, 1, replace=False)[0]
        query_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
        query_img = Image.open(os.path.join(self.base_path, query_name)).convert('RGB')
        query_mask = self.read_mask(query_name)

        org_qry_imsize = query_img.size

        # binarize the mask: 1 for the sampled class (stored label class_sample + 1), 0 elsewhere
        query_mask[query_mask != class_sample + 1] = 0
        query_mask[query_mask == class_sample + 1] = 1
        support_names = []
        if self.shot:
            while True:  # keep sampling until the support set excludes the query image
                support_name = np.random.choice(self.img_metadata_classwise[class_sample], 1, replace=False)[0]
                if query_name != support_name:
                    support_names.append(support_name)
                if len(support_names) == self.shot:
                    break
support_imgs = []
support_masks = []
if self.shot:
for support_name in support_names:
support_imgs.append(Image.open(os.path.join(self.base_path, support_name)).convert('RGB'))
support_mask = self.read_mask(support_name)
support_mask[support_mask != class_sample + 1] = 0
support_mask[support_mask == class_sample + 1] = 1
support_masks.append(support_mask)
return query_img, query_mask, support_imgs, support_masks, query_name, support_names, class_sample, org_qry_imsize
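

if __name__ == '__main__':
    # Minimal usage sketch (an illustrative assumption, not part of the original
    # file): builds one batch of 1-shot validation episodes. The datapath and the
    # Resize/ToTensor transform below are placeholder stand-ins; running this
    # requires the COCO2014 images/annotations and the fold .pkl splits on disk.
    from torch.utils.data import DataLoader
    from torchvision import transforms

    transform = transforms.Compose([transforms.Resize((400, 400)),
                                    transforms.ToTensor()])
    dataset = DatasetCOCO(datapath='path/to/data', fold=0, transform=transform,
                          split='val', shot=1, use_original_imgsize=False)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

    batch = next(iter(dataloader))
    # expected: query_img [4, 3, 400, 400], support_imgs [4, 1, 3, 400, 400]
    print(batch['query_img'].shape, batch['support_imgs'].shape, batch['class_id'])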