Spaces:
Sleeping
Sleeping
| from __future__ import division | |
| import os | |
| import shutil | |
| import json | |
| import cv2 | |
| from PIL import Image | |
| import numpy as np | |
| from torch.utils.data import Dataset | |
| from utils.image import _palette | |
class VOSTest(Dataset):
    """Frame-level dataset over a single video sequence for evaluation.

    Each item is one frame of ``seq_name``. Frames whose annotation file name
    (image stem + ``.png``) appears in ``labels`` also carry a
    ``current_label`` mask, whose raw object ids are remapped to their
    position in ``obj_indices[idx]`` (order of first appearance).

    Args:
        image_root: Directory holding one sub-directory of frames per
            sequence.
        label_root: Directory holding one sub-directory of ``.png``
            annotations per sequence.
        seq_name: Name of the sequence sub-directory.
        images: Ordered frame file names (e.g. ``'00000.jpg'``).
        labels: Annotation file names available for this sequence.
        rgb: If True, reorder OpenCV's BGR channels to RGB.
        transform: Optional callable applied to each sample dict.
        single_obj: If True, labels are binarized (any non-zero id -> 1).
        resolution: If given, the ``height``/``width`` reported in ``meta``
            are rescaled so that height equals ``resolution`` (aspect ratio
            kept, width rounded up). The image itself is not resized here.
    """

    def __init__(self,
                 image_root,
                 label_root,
                 seq_name,
                 images,
                 labels,
                 rgb=True,
                 transform=None,
                 single_obj=False,
                 resolution=None):
        self.image_root = image_root
        self.label_root = label_root
        self.seq_name = seq_name
        self.images = images
        self.labels = labels
        self.obj_num = 1
        self.num_frame = len(self.images)
        self.transform = transform
        self.rgb = rgb
        self.single_obj = single_obj
        self.resolution = resolution

        # obj_nums[i]: number of non-background objects seen in labels of the
        # frames *before* frame i.
        # obj_indices[i]: raw label ids seen up to and including frame i,
        # background (0) first, in order of first appearance.
        self.obj_nums = []
        self.obj_indices = []
        curr_objs = [0]
        for img_name in self.images:
            self.obj_nums.append(len(curr_objs) - 1)
            current_label_name = self._label_name(img_name)
            if current_label_name in self.labels:
                current_label = self.read_label(current_label_name)
                for obj_idx in np.unique(current_label):
                    if obj_idx not in curr_objs:
                        curr_objs.append(obj_idx)
            self.obj_indices.append(curr_objs.copy())
        # Frame 0 reports the objects found in its own label. This equals the
        # former ``obj_nums[1]`` but also works for single-frame sequences,
        # which previously raised IndexError.
        if self.obj_indices:
            self.obj_nums[0] = len(self.obj_indices[0]) - 1

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _label_name(img_name):
        """Return the annotation file name (stem + ``.png``) for a frame."""
        return os.path.splitext(img_name)[0] + '.png'

    @staticmethod
    def _squeeze_label(label, squeeze_idx):
        """Remap raw ids in ``label`` to their index within ``squeeze_idx``.

        Background (id 0) and ids not listed in ``squeeze_idx`` map to 0.
        The result has the same shape and dtype as ``label``.
        """
        squeezed_label = np.zeros_like(label)
        for new_id, obj_id in enumerate(squeeze_idx):
            if obj_id == 0:
                continue
            # Each pixel holds one raw id, so the masks are disjoint.
            squeezed_label[label == obj_id] = new_id
        return squeezed_label

    def read_image(self, idx):
        """Load frame ``idx`` as a float32 HxWx3 array (RGB if ``self.rgb``).

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.image_root, self.seq_name, img_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None, not by raising.
            raise FileNotFoundError('Failed to read image {}.'.format(img_path))
        img = np.array(img, dtype=np.float32)
        if self.rgb:
            img = img[:, :, [2, 1, 0]]
        return img

    def read_label(self, label_name, squeeze_idx=None):
        """Load annotation ``label_name`` of this sequence as a uint8 array.

        Args:
            label_name: Annotation file name inside the sequence folder.
            squeeze_idx: Optional list of raw ids; when given (and not in
                single-object mode) ids are remapped to their index in it.
        """
        label_path = os.path.join(self.label_root, self.seq_name, label_name)
        with Image.open(label_path) as label_img:
            label = np.array(label_img, dtype=np.uint8)
        if self.single_obj:
            label = (label > 0).astype(np.uint8)
        elif squeeze_idx is not None:
            label = self._squeeze_label(label, squeeze_idx)
        return label

    def __getitem__(self, idx):
        """Return the sample dict for frame ``idx`` (transformed if set)."""
        img_name = self.images[idx]
        current_img = self.read_image(idx)
        height, width = current_img.shape[:2]
        if self.resolution is not None:
            width = int(np.ceil(
                float(width) * self.resolution / float(height)))
            height = int(self.resolution)
        current_label_name = self._label_name(img_name)
        obj_num = self.obj_nums[idx]
        obj_idx = self.obj_indices[idx]
        if current_label_name in self.labels:
            current_label = self.read_label(current_label_name, obj_idx)
            sample = {
                'current_img': current_img,
                'current_label': current_label
            }
        else:
            sample = {'current_img': current_img}
        sample['meta'] = {
            'seq_name': self.seq_name,
            'frame_num': self.num_frame,
            'obj_num': obj_num,
            'current_name': img_name,
            'height': height,
            'width': width,
            'flip': False,
            'obj_idx': obj_idx
        }
        if self.transform is not None:
            sample = self.transform(sample)
        return sample
class YOUTUBEVOS_Test(object):
    """Sequence-level YouTube-VOS evaluation dataset (sparse frames).

    Indexing returns a ``VOSTest`` over the frames listed in ``meta.json``
    for one sequence. As a best-effort side effect, the sequence's first
    annotation file is copied into ``result_root/<seq_name>/``.

    Args:
        root: Dataset root; data is read from ``<root>/<year>/<split>``.
        year: Dataset year sub-directory.
        split: Split name; ``'val'`` is mapped to ``'valid'``.
        transform: Optional transform forwarded to each ``VOSTest``.
        rgb: Forwarded to ``VOSTest``.
        result_root: Directory where per-sequence results are written.

    Raises:
        FileNotFoundError: If ``meta.json`` is missing from the split folder.
    """

    def __init__(self,
                 root='./datasets/YTB',
                 year=2018,
                 split='val',
                 transform=None,
                 rgb=True,
                 result_root=None):
        if split == 'val':
            split = 'valid'
        root = os.path.join(root, str(year), split)
        self.db_root_dir = root
        self.result_root = result_root
        self.rgb = rgb
        self.transform = transform
        self.seq_list_file = os.path.join(self.db_root_dir, 'meta.json')
        if not self._check_preprocess():
            # Previously this fell through and crashed with AttributeError
            # on ``self.ann_f``; fail early with the offending path instead.
            raise FileNotFoundError(
                'Sequence list file {} not found.'.format(self.seq_list_file))
        self.seqs = list(self.ann_f.keys())
        self.image_root = os.path.join(root, 'JPEGImages')
        self.label_root = os.path.join(root, 'Annotations')

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        """Return a ``VOSTest`` for sequence ``idx``."""
        seq_name = self.seqs[idx]
        data = self.ann_f[seq_name]['objects']
        images = []
        labels = []
        for obj_info in data.values():
            frames = obj_info['frames']
            images.extend(frame + '.jpg' for frame in frames)
            # Each object's annotation is taken from its first listed frame.
            labels.append(frames[0] + '.png')
        images = np.sort(np.unique(images))
        labels = np.sort(np.unique(labels))
        try:
            result_dir = os.path.join(self.result_root, seq_name)
            first_label = labels[0]
            if not os.path.isfile(os.path.join(result_dir, first_label)):
                os.makedirs(result_dir, exist_ok=True)
                shutil.copy(
                    os.path.join(self.label_root, seq_name, first_label),
                    os.path.join(result_dir, first_label))
        except Exception as inst:
            # Best effort: report the failure and still return the dataset.
            print(inst)
            print('Failed to create a result folder for sequence {}.'.format(
                seq_name))
        seq_dataset = VOSTest(self.image_root,
                              self.label_root,
                              seq_name,
                              images,
                              labels,
                              transform=self.transform,
                              rgb=self.rgb)
        return seq_dataset

    def _check_preprocess(self):
        """Load ``meta.json`` videos into ``self.ann_f``.

        Returns:
            True on success, False if the file does not exist.
        """
        if not os.path.isfile(self.seq_list_file):
            return False
        with open(self.seq_list_file, 'r') as f:
            self.ann_f = json.load(f)['videos']
        return True
class YOUTUBEVOS_DenseTest(object):
    """Sequence-level YouTube-VOS evaluation dataset over all (dense) frames.

    Frames are listed from ``<root>/<year>/<split>_all_frames/JPEGImages``,
    while ``meta.json`` and annotations come from ``<root>/<year>/<split>``.
    Each sequence is trimmed to the dense frames between the first and last
    frames listed for it in ``meta.json``. As a best-effort side effect, the
    first annotation file is copied into ``result_root/<seq_name>/``.

    Args:
        root: Dataset root directory.
        year: Dataset year sub-directory.
        split: Split name; ``'val'`` is mapped to ``'valid'``.
        transform: Optional transform forwarded to each ``VOSTest``.
        rgb: Forwarded to ``VOSTest``.
        result_root: Directory where per-sequence results are written.

    Raises:
        FileNotFoundError: If ``meta.json`` is missing from the split folder.
    """

    def __init__(self,
                 root='./datasets/YTB',
                 year=2018,
                 split='val',
                 transform=None,
                 rgb=True,
                 result_root=None):
        if split == 'val':
            split = 'valid'
        root_sparse = os.path.join(root, str(year), split)
        root_dense = root_sparse + '_all_frames'
        self.db_root_dir = root_dense
        self.result_root = result_root
        self.rgb = rgb
        self.transform = transform
        self.seq_list_file = os.path.join(root_sparse, 'meta.json')
        if not self._check_preprocess():
            # Previously this fell through and crashed with AttributeError
            # on ``self.ann_f``; fail early with the offending path instead.
            raise FileNotFoundError(
                'Sequence list file {} not found.'.format(self.seq_list_file))
        self.seqs = list(self.ann_f.keys())
        self.image_root = os.path.join(root_dense, 'JPEGImages')
        self.label_root = os.path.join(root_sparse, 'Annotations')

    def __len__(self):
        return len(self.seqs)

    @staticmethod
    def _trim_to_annotated_range(images, start_img, end_img):
        """Return ``images`` sliced from ``start_img`` through ``end_img``.

        A name matches when the boundary name is a substring of it; the first
        match is the start and the last match is the end (inclusive).

        Raises:
            ValueError: If either boundary frame is not present, instead of
                silently producing a wrong or empty range.
        """
        start_idx = next(
            (i for i, name in enumerate(images) if start_img in name), None)
        end_idx = next((i for i in reversed(range(len(images)))
                        if end_img in images[i]), None)
        if start_idx is None or end_idx is None:
            raise ValueError(
                'Frames {} .. {} not found among {} dense frames.'.format(
                    start_img, end_img, len(images)))
        return images[start_idx:end_idx + 1]

    def __getitem__(self, idx):
        """Return a ``VOSTest`` over the dense frames of sequence ``idx``.

        The returned dataset also gets an ``images_sparse`` attribute holding
        the sorted frame names listed in ``meta.json``.
        """
        seq_name = self.seqs[idx]
        data = self.ann_f[seq_name]['objects']
        images_sparse = []
        for obj_info in data.values():
            images_sparse.extend(
                frame + '.jpg' for frame in obj_info['frames'])
        images_sparse = np.sort(np.unique(images_sparse))
        images = np.sort(os.listdir(os.path.join(self.image_root, seq_name)))
        images = self._trim_to_annotated_range(images, images_sparse[0],
                                               images_sparse[-1])
        labels = np.sort(os.listdir(os.path.join(self.label_root, seq_name)))
        try:
            result_dir = os.path.join(self.result_root, seq_name)
            first_label = labels[0]
            if not os.path.isfile(os.path.join(result_dir, first_label)):
                os.makedirs(result_dir, exist_ok=True)
                shutil.copy(
                    os.path.join(self.label_root, seq_name, first_label),
                    os.path.join(result_dir, first_label))
        except Exception as inst:
            # Best effort: report the failure and still return the dataset.
            print(inst)
            print('Failed to create a result folder for sequence {}.'.format(
                seq_name))
        seq_dataset = VOSTest(self.image_root,
                              self.label_root,
                              seq_name,
                              images,
                              labels,
                              transform=self.transform,
                              rgb=self.rgb)
        seq_dataset.images_sparse = images_sparse
        return seq_dataset

    def _check_preprocess(self):
        """Load ``meta.json`` videos into ``self.ann_f``.

        Returns:
            True on success, False if the file does not exist.
        """
        if not os.path.isfile(self.seq_list_file):
            return False
        with open(self.seq_list_file, 'r') as f:
            self.ann_f = json.load(f)['videos']
        return True
class DAVIS_Test(object):
    """Sequence-level DAVIS evaluation dataset.

    Sequence names are read from ``<root>/ImageSets/<year>/<split>.txt``.
    Indexing returns a ``VOSTest`` over all frames of one sequence, using the
    first frame's annotation as the only label. As a side effect the first
    annotation is written to ``result_root/<seq_name>/`` if not already there
    (binarized when ``year == 2016``).

    Args:
        split: A split name or a sequence of split names; ``'test'`` is
            mapped to ``'test-dev'``.
        root: DAVIS root directory.
        year: ImageSets year; 2016 enables single-object (binarized) labels.
        transform: Optional transform forwarded to each ``VOSTest``.
        rgb: Forwarded to ``VOSTest``.
        full_resolution: Use ``Full-Resolution`` frames instead of ``480p``.
        result_root: Directory where per-sequence results are written.
    """

    def __init__(self,
                 split=('val',),
                 root='./DAVIS',
                 year=2017,
                 transform=None,
                 rgb=True,
                 full_resolution=False,
                 result_root=None):
        self.transform = transform
        self.rgb = rgb
        self.result_root = result_root
        self.single_obj = (year == 2016)
        resolution = 'Full-Resolution' if full_resolution else '480p'
        self.image_root = os.path.join(root, 'JPEGImages', resolution)
        self.label_root = os.path.join(root, 'Annotations', resolution)

        if isinstance(split, str):
            # A bare string would otherwise be iterated character by character.
            split = [split]
        seq_names = []
        for spt in split:
            if spt == 'test':
                spt = 'test-dev'
            list_path = os.path.join(root, 'ImageSets', str(year),
                                     spt + '.txt')
            with open(list_path) as f:
                for line in f:
                    name = line.strip()
                    if name:  # skip blank lines instead of adding '' as a sequence
                        seq_names.append(name)
        self.seqs = list(np.unique(seq_names))

    def __len__(self):
        return len(self.seqs)

    def _prepare_result_label(self, seq_name, label_name):
        """Write the reference label into ``result_root/<seq_name>/`` if absent."""
        seq_result_folder = os.path.join(self.result_root, seq_name)
        result_label_path = os.path.join(seq_result_folder, label_name)
        if os.path.isfile(result_label_path):
            return
        try:
            os.makedirs(seq_result_folder, exist_ok=True)
        except OSError as inst:
            print(inst)
            print('Failed to create a result folder for sequence {}.'.format(
                seq_name))
        source_label_path = os.path.join(self.label_root, seq_name,
                                         label_name)
        if self.single_obj:
            with Image.open(source_label_path) as label_img:
                label = np.array(label_img, dtype=np.uint8)
            label = (label > 0).astype(np.uint8)
            label = Image.fromarray(label).convert('P')
            label.putpalette(_palette)
            label.save(result_label_path)
        else:
            shutil.copy(source_label_path, result_label_path)

    def __getitem__(self, idx):
        """Return a ``VOSTest`` for sequence ``idx``."""
        seq_name = self.seqs[idx]
        images = list(
            np.sort(os.listdir(os.path.join(self.image_root, seq_name))))
        # The first frame's annotation is the only label provided.
        labels = [os.path.splitext(images[0])[0] + '.png']
        self._prepare_result_label(seq_name, labels[0])
        seq_dataset = VOSTest(self.image_root,
                              self.label_root,
                              seq_name,
                              images,
                              labels,
                              transform=self.transform,
                              rgb=self.rgb,
                              single_obj=self.single_obj,
                              resolution=480)
        return seq_dataset
class _EVAL_TEST(Dataset):
    """Synthetic 10-frame sequence of all-black 400x400 frames.

    Presumably used to exercise the evaluation pipeline without real data.
    Only frame 0 carries a label, in which every pixel has id 2.

    Args:
        transform: Optional callable applied to each sample dict.
        seq_name: Name reported in each sample's ``meta``.
    """

    def __init__(self, transform, seq_name):
        self.seq_name = seq_name
        self.num_frame = 10
        self.transform = transform

    def __len__(self):
        return self.num_frame

    def __getitem__(self, idx):
        obj_count = 2
        side = 400
        sample = {'current_img': np.zeros((side, side, 3), dtype=np.float32)}
        if idx == 0:
            # Only the first frame is annotated.
            sample['current_label'] = np.full((side, side), obj_count,
                                              dtype=np.uint8)
        sample['meta'] = {
            'seq_name': self.seq_name,
            'frame_num': self.num_frame,
            'obj_num': obj_count,
            'current_name': 'test{}.jpg'.format(idx),
            'height': side,
            'width': side,
            'flip': False
        }
        return sample if self.transform is None else self.transform(sample)
class EVAL_TEST(object):
    """Collection of three synthetic sequences named ``test1``..``test3``.

    Indexing creates ``result_root/<seq_name>`` when it does not exist yet
    and returns an ``_EVAL_TEST`` frame dataset for that sequence.

    Args:
        transform: Optional transform forwarded to each ``_EVAL_TEST``.
        result_root: Directory under which per-sequence folders are created.
    """

    def __init__(self, transform=None, result_root=None):
        self.transform = transform
        self.result_root = result_root
        self.seqs = ['test{}'.format(i) for i in range(1, 4)]

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        name = self.seqs[idx]
        out_dir = os.path.join(self.result_root, name)
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        return _EVAL_TEST(self.transform, name)