from __future__ import division
import os
from glob import glob
import json
import random

import cv2
from PIL import Image
import numpy as np

import torch
from torch.utils.data import Dataset
import torchvision.transforms as TF

import dataloaders.image_transforms as IT

cv2.setNumThreads(0)
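

# Helpers to flatten a sample dict into an ordered frame list:
# [ref, prev, curr_0, ..., curr_{n-1}].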
def _get_images(sample):
    return [sample['ref_img'], sample['prev_img']] + sample['curr_img']


def _get_labels(sample):
    return [sample['ref_label'], sample['prev_label']] + sample['curr_label']
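

# Paste sample2's foreground objects onto sample1, frame by frame. Object IDs
# from sample2 are shifted by max_obj_n so the two ID spaces do not collide;
# objects smaller than min_obj_pixels in the first frame are dropped.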
def _merge_sample(sample1, sample2, min_obj_pixels=100, max_obj_n=10):
    sample1_images = _get_images(sample1)
    sample2_images = _get_images(sample2)

    sample1_labels = _get_labels(sample1)
    sample2_labels = _get_labels(sample2)

    obj_idx = torch.arange(0, max_obj_n * 2 + 1).view(max_obj_n * 2 + 1, 1, 1)
    selected_idx = None
    selected_obj = None

    all_img = []
    all_mask = []
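    # Composite each aligned frame pair: where sample2 has foreground, its
    # pixels (and shifted labels) replace sample1's.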
    for idx, (s1_img, s2_img, s1_label, s2_label) in enumerate(
            zip(sample1_images, sample2_images, sample1_labels,
                sample2_labels)):
        s2_fg = (s2_label > 0).float()
        s2_bg = 1 - s2_fg
        merged_img = s1_img * s2_bg + s2_img * s2_fg
        merged_mask = s1_label * s2_bg.long() + (
            (s2_label + max_obj_n) * s2_fg.long())
        merged_mask = (merged_mask == obj_idx).float()
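        # On the first (reference) frame, decide which object channels to
        # keep: background plus every object with enough visible pixels.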
        if idx == 0:
            after_merge_pixels = merged_mask.sum(dim=(1, 2), keepdim=True)
            selected_idx = after_merge_pixels > min_obj_pixels
            selected_idx[0] = True
            obj_num = selected_idx.sum().int().item() - 1
            selected_idx = selected_idx.expand(-1,
                                               s1_label.size()[1],
                                               s1_label.size()[2])
            if obj_num > max_obj_n:
                selected_obj = list(range(1, obj_num + 1))
                random.shuffle(selected_obj)
                selected_obj = [0] + selected_obj[:max_obj_n]
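        # Keep only the selected channels, nudge the background channel so
        # argmax resolves all-zero pixels (dropped objects) to background,
        # then collapse the one-hot mask back to an ID map.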
        merged_mask = merged_mask[selected_idx].view(obj_num + 1,
                                                     s1_label.size()[1],
                                                     s1_label.size()[2])
        if obj_num > max_obj_n:
            merged_mask = merged_mask[selected_obj]
        merged_mask[0] += 0.1
        merged_mask = torch.argmax(merged_mask, dim=0, keepdim=True).long()

        all_img.append(merged_img)
        all_mask.append(merged_mask)
    sample = {
        'ref_img': all_img[0],
        'prev_img': all_img[1],
        'curr_img': all_img[2:],
        'ref_label': all_mask[0],
        'prev_label': all_mask[1],
        'curr_label': all_mask[2:]
    }
    sample['meta'] = sample1['meta']
    sample['meta']['obj_num'] = min(obj_num, max_obj_n)

    return sample
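

# Pre-training dataset built from static image datasets: each image/mask pair
# is augmented into a pseudo video clip of seq_len frames.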
class StaticTrain(Dataset):
    def __init__(self,
                 root,
                 output_size,
                 seq_len=5,
                 max_obj_n=10,
                 dynamic_merge=True,
                 merge_prob=1.0,
                 aug_type='v1'):
        self.root = root
        self.clip_n = seq_len
        self.output_size = output_size
        self.max_obj_n = max_obj_n
        self.dynamic_merge = dynamic_merge
        self.merge_prob = merge_prob

        self.img_list = list()
        self.mask_list = list()

        dataset_list = list()
        lines = ['COCO', 'ECSSD', 'MSRA10K', 'PASCAL-S', 'PASCALVOC2012']
        for line in lines:
            dataset_name = line.strip()

            img_dir = os.path.join(root, 'JPEGImages', dataset_name)
            mask_dir = os.path.join(root, 'Annotations', dataset_name)

            img_list = sorted(glob(os.path.join(img_dir, '*.jpg'))) + \
                sorted(glob(os.path.join(img_dir, '*.png')))
            mask_list = sorted(glob(os.path.join(mask_dir, '*.png')))

            if len(img_list) > 0:
                if len(img_list) == len(mask_list):
                    dataset_list.append(dataset_name)
                    self.img_list += img_list
                    self.mask_list += mask_list
                    print(f'\t{dataset_name}: {len(img_list)} imgs.')
                else:
                    print(
                        f'\tPreTrain dataset {dataset_name} has {len(img_list)} imgs but {len(mask_list)} annots. Mismatch! Skip.'
                    )
            else:
                print(
                    f'\tPreTrain dataset {dataset_name} doesn\'t exist. Skip.')

        print(
            f'{len(self.img_list)} imgs are used for PreTrain. They are from {dataset_list}.'
        )
        self.aug_type = aug_type
        self.pre_random_horizontal_flip = IT.RandomHorizontalFlip(0.5)

        self.random_horizontal_flip = IT.RandomHorizontalFlip(0.3)
        if self.aug_type == 'v1':
            self.color_jitter = TF.ColorJitter(0.1, 0.1, 0.1, 0.03)
        elif self.aug_type == 'v2':
            self.color_jitter = TF.RandomApply(
                [TF.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8)
            self.gray_scale = TF.RandomGrayscale(p=0.2)
            self.blur = TF.RandomApply([IT.GaussianBlur([.1, 2.])], p=0.3)
        else:
            raise NotImplementedError(
                f'Unsupported aug_type: {self.aug_type}.')

        self.random_affine = IT.RandomAffine(degrees=20,
                                             translate=(0.1, 0.1),
                                             scale=(0.9, 1.1),
                                             shear=10,
                                             resample=Image.BICUBIC,
                                             fillcolor=(124, 116, 104))

        base_ratio = float(output_size[1]) / output_size[0]
        self.random_resize_crop = IT.RandomResizedCrop(
            output_size, (0.8, 1),
            ratio=(base_ratio * 3. / 4., base_ratio * 4. / 3.),
            interpolation=Image.BICUBIC)
        self.to_tensor = TF.ToTensor()
        self.to_onehot = IT.ToOnehot(max_obj_n, shuffle=True)
        self.normalize = TF.Normalize((0.485, 0.456, 0.406),
                                      (0.229, 0.224, 0.225))

    def __len__(self):
        return len(self.img_list)

    def load_image_in_PIL(self, path, mode='RGB'):
        img = Image.open(path)
        img.load()  # Force data loading now; important for large images.
        return img.convert(mode)
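
    # Build a pseudo clip: flip the source image once, then generate clip_n
    # frames by re-applying random affine/crop/color augmentations. Frame 0
    # fixes the object-ID ordering that the later frames reuse.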
    def sample_sequence(self, idx):
        img_pil = self.load_image_in_PIL(self.img_list[idx], 'RGB')
        mask_pil = self.load_image_in_PIL(self.mask_list[idx], 'P')

        frames = []
        masks = []

        img_pil, mask_pil = self.pre_random_horizontal_flip(img_pil, mask_pil)
        # img_pil, mask_pil = self.pre_random_vertical_flip(img_pil, mask_pil)

        for i in range(self.clip_n):
            img, mask = img_pil, mask_pil
            if i > 0:
                img, mask = self.random_horizontal_flip(img, mask)
                img, mask = self.random_affine(img, mask)
                img = self.color_jitter(img)

            img, mask = self.random_resize_crop(img, mask)

            if self.aug_type == 'v2':
                img = self.gray_scale(img)
                img = self.blur(img)

            mask = np.array(mask, np.uint8)

            if i == 0:
                mask, obj_list = self.to_onehot(mask)
                obj_num = len(obj_list)
            else:
                mask, _ = self.to_onehot(mask, obj_list)

            mask = torch.argmax(mask, dim=0, keepdim=True)

            frames.append(self.normalize(self.to_tensor(img)))
            masks.append(mask)

        sample = {
            'ref_img': frames[0],
            'prev_img': frames[1],
            'curr_img': frames[2:],
            'ref_label': masks[0],
            'prev_label': masks[1],
            'curr_label': masks[2:]
        }
        sample['meta'] = {
            'seq_name': self.img_list[idx],
            'frame_num': 1,
            'obj_num': obj_num
        }

        return sample

    def __getitem__(self, idx):
        sample1 = self.sample_sequence(idx)

        if self.dynamic_merge and (sample1['meta']['obj_num'] == 0
                                   or random.random() < self.merge_prob):
            rand_idx = np.random.randint(len(self.img_list))
            while (rand_idx == idx):
                rand_idx = np.random.randint(len(self.img_list))
            sample2 = self.sample_sequence(rand_idx)
            sample = self.merge_sample(sample1, sample2)
        else:
            sample = sample1

        return sample

    def merge_sample(self, sample1, sample2, min_obj_pixels=100):
        return _merge_sample(sample1, sample2, min_obj_pixels, self.max_obj_n)
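

# Generic video training dataset: samples clips of (ref, prev, curr...) frames
# from annotated sequences, with random frame gaps, optional temporal
# reversal, and optional dynamic merging of two clips.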
class VOSTrain(Dataset):
    def __init__(self,
                 image_root,
                 label_root,
                 imglistdic,
                 transform=None,
                 rgb=True,
                 repeat_time=1,
                 rand_gap=3,
                 seq_len=5,
                 rand_reverse=True,
                 dynamic_merge=True,
                 enable_prev_frame=False,
                 merge_prob=0.3,
                 max_obj_n=10):
        self.image_root = image_root
        self.label_root = label_root
        self.rand_gap = rand_gap
        self.seq_len = seq_len
        self.rand_reverse = rand_reverse
        self.repeat_time = repeat_time
        self.transform = transform
        self.dynamic_merge = dynamic_merge
        self.merge_prob = merge_prob
        self.enable_prev_frame = enable_prev_frame
        self.max_obj_n = max_obj_n
        self.rgb = rgb
        self.imglistdic = imglistdic
        self.seqs = list(self.imglistdic.keys())
        print('Video Num: {} X {}'.format(len(self.seqs), self.repeat_time))

    def __len__(self):
        return int(len(self.seqs) * self.repeat_time)

    def reverse_seq(self, imagelist, lablist):
        if np.random.randint(2) == 1:
            imagelist = imagelist[::-1]
            lablist = lablist[::-1]
        return imagelist, lablist

    def get_ref_index(self,
                      seqname,
                      lablist,
                      objs,
                      min_fg_pixels=200,
                      max_try=5):
        bad_indices = []
        for _ in range(max_try):
            ref_index = np.random.randint(len(lablist))
            if ref_index in bad_indices:
                continue
            ref_label = Image.open(
                os.path.join(self.label_root, seqname, lablist[ref_index]))
            ref_label = np.array(ref_label, dtype=np.uint8)
            ref_objs = list(np.unique(ref_label))
            is_consistent = True
            for obj in ref_objs:
                if obj == 0:
                    continue
                if obj not in objs:
                    is_consistent = False
            xs, ys = np.nonzero(ref_label)
            if len(xs) > min_fg_pixels and is_consistent:
                break
            bad_indices.append(ref_index)
        return ref_index

    def get_ref_index_v2(self,
                         seqname,
                         lablist,
                         min_fg_pixels=200,
                         max_try=20,
                         total_gap=0):
        search_range = len(lablist) - total_gap
        if search_range <= 1:
            return 0
        bad_indices = []
        for _ in range(max_try):
            ref_index = np.random.randint(search_range)
            if ref_index in bad_indices:
                continue
            ref_label = Image.open(
                os.path.join(self.label_root, seqname, lablist[ref_index]))
            ref_label = np.array(ref_label, dtype=np.uint8)
            xs, ys = np.nonzero(ref_label)
            if len(xs) > min_fg_pixels:
                break
            bad_indices.append(ref_index)
        return ref_index
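
    # Draw seq_len random gaps, one per sampled frame, each in [1, rand_gap];
    # retry up to max_try times until the clip fits within max_gap frames.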
    def get_curr_gaps(self, seq_len, max_gap=999, max_try=10):
        for _ in range(max_try):
            curr_gaps = []
            total_gap = 0
            for _ in range(seq_len):
                gap = int(np.random.randint(self.rand_gap) + 1)
                total_gap += gap
                curr_gaps.append(gap)
            if total_gap <= max_gap:
                break
        return curr_gaps, total_gap

    def get_prev_index(self, lablist, total_gap):
        search_range = len(lablist) - total_gap
        if search_range > 1:
            prev_index = np.random.randint(search_range)
        else:
            prev_index = 0
        return prev_index
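
    # Clamp an index into [0, total_len): out-of-range indices are reflected
    # back into the sequence (mirror padding) when allow_reflect is True.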
    def check_index(self, total_len, index, allow_reflect=True):
        if total_len <= 1:
            return 0

        if index < 0:
            if allow_reflect:
                index = -index
                index = self.check_index(total_len, index, True)
            else:
                index = 0
        elif index >= total_len:
            if allow_reflect:
                index = 2 * (total_len - 1) - index
                index = self.check_index(total_len, index, True)
            else:
                index = total_len - 1

        return index

    def get_curr_indices(self, lablist, prev_index, gaps):
        total_len = len(lablist)
        curr_indices = []
        now_index = prev_index
        for gap in gaps:
            now_index += gap
            curr_indices.append(self.check_index(total_len, now_index))
        return curr_indices

    def get_image_label(self, seqname, imagelist, lablist, index):
        image = cv2.imread(
            os.path.join(self.image_root, seqname, imagelist[index]))
        image = np.array(image, dtype=np.float32)
        if self.rgb:
            image = image[:, :, [2, 1, 0]]

        label = Image.open(
            os.path.join(self.label_root, seqname, lablist[index]))
        label = np.array(label, dtype=np.uint8)

        return image, label
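
    # Sample one training clip. Re-samples up to max_try times until every
    # object appearing in the clip is also present in the reference frame.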
    def sample_sequence(self, idx):
        idx = idx % len(self.seqs)
        seqname = self.seqs[idx]
        imagelist, lablist = self.imglistdic[seqname]
        frame_num = len(imagelist)
        if self.rand_reverse:
            imagelist, lablist = self.reverse_seq(imagelist, lablist)

        is_consistent = False
        max_try = 5
        try_step = 0
        while (is_consistent is False and try_step < max_try):
            try_step += 1

            # generate random gaps
            curr_gaps, total_gap = self.get_curr_gaps(self.seq_len - 1)

            if self.enable_prev_frame:  # prev frame is randomly sampled
                # get prev frame
                prev_index = self.get_prev_index(lablist, total_gap)
                prev_image, prev_label = self.get_image_label(
                    seqname, imagelist, lablist, prev_index)
                prev_objs = list(np.unique(prev_label))

                # get curr frames
                curr_indices = self.get_curr_indices(lablist, prev_index,
                                                     curr_gaps)
                curr_images, curr_labels, curr_objs = [], [], []
                for curr_index in curr_indices:
                    curr_image, curr_label = self.get_image_label(
                        seqname, imagelist, lablist, curr_index)
                    c_objs = list(np.unique(curr_label))
                    curr_images.append(curr_image)
                    curr_labels.append(curr_label)
                    curr_objs.extend(c_objs)

                objs = list(np.unique(prev_objs + curr_objs))

                start_index = prev_index
                end_index = max(curr_indices)

                # get ref frame
                _try_step = 0
                ref_index = self.get_ref_index_v2(seqname, lablist)
                while (ref_index > start_index and ref_index <= end_index
                       and _try_step < max_try):
                    _try_step += 1
                    ref_index = self.get_ref_index_v2(seqname, lablist)

                ref_image, ref_label = self.get_image_label(
                    seqname, imagelist, lablist, ref_index)
                ref_objs = list(np.unique(ref_label))
            else:  # prev frame is next to ref frame
                # get ref frame
                ref_index = self.get_ref_index_v2(seqname, lablist)

                ref_image, ref_label = self.get_image_label(
                    seqname, imagelist, lablist, ref_index)
                ref_objs = list(np.unique(ref_label))

                # get curr frames
                curr_indices = self.get_curr_indices(lablist, ref_index,
                                                     curr_gaps)
                curr_images, curr_labels, curr_objs = [], [], []
                for curr_index in curr_indices:
                    curr_image, curr_label = self.get_image_label(
                        seqname, imagelist, lablist, curr_index)
                    c_objs = list(np.unique(curr_label))
                    curr_images.append(curr_image)
                    curr_labels.append(curr_label)
                    curr_objs.extend(c_objs)
                objs = list(np.unique(curr_objs))
                prev_image, prev_label = curr_images[0], curr_labels[0]
                curr_images, curr_labels = curr_images[1:], curr_labels[1:]

            is_consistent = True
            for obj in objs:
                if obj == 0:
                    continue
                if obj not in ref_objs:
                    is_consistent = False
                    break

        # get meta info
        obj_num = list(np.sort(ref_objs))[-1]

        sample = {
            'ref_img': ref_image,
            'prev_img': prev_image,
            'curr_img': curr_images,
            'ref_label': ref_label,
            'prev_label': prev_label,
            'curr_label': curr_labels
        }
        sample['meta'] = {
            'seq_name': seqname,
            'frame_num': frame_num,
            'obj_num': obj_num
        }

        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __getitem__(self, idx):
        sample1 = self.sample_sequence(idx)

        if self.dynamic_merge and (sample1['meta']['obj_num'] == 0
                                   or random.random() < self.merge_prob):
            rand_idx = np.random.randint(len(self.seqs))
            while (rand_idx == (idx % len(self.seqs))):
                rand_idx = np.random.randint(len(self.seqs))
            sample2 = self.sample_sequence(rand_idx)
            sample = self.merge_sample(sample1, sample2)
        else:
            sample = sample1

        return sample

    def merge_sample(self, sample1, sample2, min_obj_pixels=100):
        return _merge_sample(sample1, sample2, min_obj_pixels, self.max_obj_n)
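

# DAVIS 2017 training set wrapper: resolves the resolution directory, reads
# the split file(s), and builds the per-sequence image/label lists for
# VOSTrain.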
class DAVIS2017_Train(VOSTrain):
    def __init__(self,
                 split=['train'],
                 root='./DAVIS',
                 transform=None,
                 rgb=True,
                 repeat_time=1,
                 full_resolution=True,
                 year=2017,
                 rand_gap=3,
                 seq_len=5,
                 rand_reverse=True,
                 dynamic_merge=True,
                 enable_prev_frame=False,
                 max_obj_n=10,
                 merge_prob=0.3):
        if full_resolution:
            resolution = 'Full-Resolution'
            if not os.path.exists(os.path.join(root, 'JPEGImages',
                                               resolution)):
                print('No Full-Resolution data, using 480p instead.')
                resolution = '480p'
        else:
            resolution = '480p'
        image_root = os.path.join(root, 'JPEGImages', resolution)
        label_root = os.path.join(root, 'Annotations', resolution)
        seq_names = []
        for spt in split:
            with open(os.path.join(root, 'ImageSets', str(year),
                                   spt + '.txt')) as f:
                seqs_tmp = f.readlines()
            seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp))
            seq_names.extend(seqs_tmp)

        imglistdic = {}
        for seq_name in seq_names:
            images = list(
                np.sort(os.listdir(os.path.join(image_root, seq_name))))
            labels = list(
                np.sort(os.listdir(os.path.join(label_root, seq_name))))
            imglistdic[seq_name] = (images, labels)

        super(DAVIS2017_Train, self).__init__(image_root,
                                              label_root,
                                              imglistdic,
                                              transform,
                                              rgb,
                                              repeat_time,
                                              rand_gap,
                                              seq_len,
                                              rand_reverse,
                                              dynamic_merge,
                                              enable_prev_frame,
                                              merge_prob=merge_prob,
                                              max_obj_n=max_obj_n)
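

# YouTube-VOS training set wrapper: parses meta.json to collect, for each
# video, the union of frames annotated for any object.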
class YOUTUBEVOS_Train(VOSTrain):
    def __init__(self,
                 root='./datasets/YTB',
                 year=2019,
                 transform=None,
                 rgb=True,
                 rand_gap=3,
                 seq_len=3,
                 rand_reverse=True,
                 dynamic_merge=True,
                 enable_prev_frame=False,
                 max_obj_n=10,
                 merge_prob=0.3):
        root = os.path.join(root, str(year), 'train')
        image_root = os.path.join(root, 'JPEGImages')
        label_root = os.path.join(root, 'Annotations')
        self.seq_list_file = os.path.join(root, 'meta.json')
        self._check_preprocess()
        seq_names = list(self.ann_f.keys())

        imglistdic = {}
        for seq_name in seq_names:
            data = self.ann_f[seq_name]['objects']
            obj_names = list(data.keys())
            images = []
            labels = []
            for obj_n in obj_names:
                if len(data[obj_n]["frames"]) < 2:
                    print("Short object: " + seq_name + '-' + obj_n)
                    continue
                images += list(
                    map(lambda x: x + '.jpg', list(data[obj_n]["frames"])))
                labels += list(
                    map(lambda x: x + '.png', list(data[obj_n]["frames"])))
            images = np.sort(np.unique(images))
            labels = np.sort(np.unique(labels))
            if len(images) < 2:
                print("Short video: " + seq_name)
                continue
            imglistdic[seq_name] = (images, labels)

        super(YOUTUBEVOS_Train, self).__init__(image_root,
                                               label_root,
                                               imglistdic,
                                               transform,
                                               rgb,
                                               1,
                                               rand_gap,
                                               seq_len,
                                               rand_reverse,
                                               dynamic_merge,
                                               enable_prev_frame,
                                               merge_prob=merge_prob,
                                               max_obj_n=max_obj_n)

    def _check_preprocess(self):
        if not os.path.isfile(self.seq_list_file):
            print('No such file: {}.'.format(self.seq_list_file))
            return False
        else:
            with open(self.seq_list_file, 'r') as f:
                self.ann_f = json.load(f)['videos']
            return True
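

# Synthetic smoke-test dataset: constant images and all-ones labels, useful
# for checking the training pipeline without real data.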
class TEST(Dataset):
    def __init__(
        self,
        seq_len=3,
        obj_num=3,
        transform=None,
    ):
        self.seq_len = seq_len
        self.obj_num = obj_num
        self.transform = transform

    def __len__(self):
        return 3000

    def __getitem__(self, idx):
        img = np.zeros((800, 800, 3)).astype(np.float32)
        label = np.ones((800, 800)).astype(np.uint8)
        sample = {
            'ref_img': img,
            'prev_img': img,
            'curr_img': [img] * (self.seq_len - 2),
            'ref_label': label,
            'prev_label': label,
            'curr_label': [label] * (self.seq_len - 2)
        }
        sample['meta'] = {
            'seq_name': 'test',
            'frame_num': 100,
            'obj_num': self.obj_num
        }

        if self.transform is not None:
            sample = self.transform(sample)
        return sample
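

# Minimal usage sketch, assuming the repo's dataloaders package is importable.
# The dataset root below is a hypothetical path; with real data in place, each
# sample is a dict of normalized image tensors and integer ID label maps.
if __name__ == '__main__':
    pretrain_data = StaticTrain(root='./datasets/Static',  # hypothetical path
                                output_size=(465, 465),
                                seq_len=5,
                                max_obj_n=10)
    print(len(pretrain_data), 'pre-training samples')
    if len(pretrain_data) > 0:
        sample = pretrain_data[0]
        print(sample['ref_img'].shape, sample['ref_label'].shape)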