import csv
import json
import math
import os
import os.path as osp
import pickle
import random

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as data

from .augmentation import RGBDAugmentor
from .rgbd_utils import *


class RGBDDataset(data.Dataset):
    def __init__(self, name, datapath, n_frames=4, crop_size=[384, 512], fmin=8.0, fmax=75.0, do_aug=True):
        """ Base class for RGB-D datasets """
        self.aug = None
        self.root = datapath
        self.name = name

        self.n_frames = n_frames
        self.fmin = fmin  # exclude very easy examples
        self.fmax = fmax  # exclude very hard examples

        if do_aug:
            self.aug = RGBDAugmentor(crop_size=crop_size)

        # building the dataset index is expensive, so cache it on disk and
        # only rebuild when no cache file exists
        cur_path = osp.dirname(osp.abspath(__file__))
        if not os.path.isdir(osp.join(cur_path, 'cache')):
            os.mkdir(osp.join(cur_path, 'cache'))

        cache_path = osp.join(cur_path, 'cache', '{}.pickle'.format(self.name))
        if osp.isfile(cache_path):
            with open(cache_path, 'rb') as cachefile:
                scene_info = pickle.load(cachefile)[0]
        else:
            scene_info = self._build_dataset()
            with open(cache_path, 'wb') as cachefile:
                pickle.dump((scene_info,), cachefile)

        self.scene_info = scene_info
        self._build_dataset_index()

    def _build_dataset_index(self):
        self.dataset_index = []
        for scene in self.scene_info:
            if not self.__class__.is_test_scene(scene):
                graph = self.scene_info[scene]['graph']
                for i in graph:
                    if len(graph[i][0]) > self.n_frames:
                        self.dataset_index.append((scene, i))
            else:
                print("Reserving {} for validation".format(scene))

    @staticmethod
    def image_read(image_file):
        return cv2.imread(image_file)

    @staticmethod
    def depth_read(depth_file):
        return np.load(depth_file)
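
    @staticmethod
    def is_test_scene(scene):
        # Default stub (an assumption: this file only calls the hook, it
        # never defines it). Subclasses that reserve scenes for validation
        # are expected to override this; the base class holds nothing back.
        return False

    def _build_dataset(self):
        # Assumed subclass hook: __init__ calls this on a cache miss, so a
        # dataset-specific subclass must scan its own directory layout and
        # return a dict of the form
        #   {scene: {'images': [...], 'depths': [...], 'poses': [...],
        #            'intrinsics': [...], 'graph': {...}}}
        raise NotImplementedError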

    def build_frame_graph(self, poses, depths, intrinsics, f=16, max_flow=256):
        """ compute optical flow distance between all pairs of frames """

        def read_disp(fn):
            # subsample the depth map by a factor of f, fill near-zero
            # depths with the mean, and convert to disparity
            depth = self.__class__.depth_read(fn)[f//2::f, f//2::f]
            depth[depth < 0.01] = np.mean(depth)
            return 1.0 / depth

        poses = np.array(poses)
        intrinsics = np.array(intrinsics) / f

        disps = np.stack(list(map(read_disp, depths)), 0)
        d = f * compute_distance_matrix_flow(poses, disps, intrinsics)

        # uncomment for a visualization of the distance matrix
        # import matplotlib.pyplot as plt
        # plt.imshow(d)
        # plt.show()

        graph = {}
        for i in range(d.shape[0]):
            j, = np.where(d[i] < max_flow)
            graph[i] = (j, d[i, j])

        return graph
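
    # Note on the structure returned above (illustrative values only):
    #   graph[0] == (array([1, 2]), array([12.3, 48.7]))
    # means frames 1 and 2 lie within max_flow pixels of mean induced flow
    # from frame 0, at roughly 12 and 49 pixels respectively.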

    def __getitem__(self, index):
        """ return training video """

        index = index % len(self.dataset_index)
        scene_id, ix = self.dataset_index[index]

        frame_graph = self.scene_info[scene_id]['graph']
        images_list = self.scene_info[scene_id]['images']
        depths_list = self.scene_info[scene_id]['depths']
        poses_list = self.scene_info[scene_id]['poses']
        intrinsics_list = self.scene_info[scene_id]['intrinsics']

        inds = [ix]
        while len(inds) < self.n_frames:
            # get other frames within the flow thresholds
            k = (frame_graph[ix][1] > self.fmin) & (frame_graph[ix][1] < self.fmax)
            frames = frame_graph[ix][0][k]

            # prefer frames forward in time
            if np.count_nonzero(frames[frames > ix]):
                ix = np.random.choice(frames[frames > ix])
            elif np.count_nonzero(frames):
                ix = np.random.choice(frames)

            inds += [ix]

        images, depths, poses, intrinsics = [], [], [], []
        for i in inds:
            images.append(self.__class__.image_read(images_list[i]))
            depths.append(self.__class__.depth_read(depths_list[i]))
            poses.append(poses_list[i])
            intrinsics.append(intrinsics_list[i])

        images = np.stack(images).astype(np.float32)
        depths = np.stack(depths).astype(np.float32)
        poses = np.stack(poses).astype(np.float32)
        intrinsics = np.stack(intrinsics).astype(np.float32)

        images = torch.from_numpy(images).float()
        images = images.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        disps = torch.from_numpy(1.0 / depths)
        poses = torch.from_numpy(poses)
        intrinsics = torch.from_numpy(intrinsics)

        if self.aug is not None:
            images, poses, disps, intrinsics = \
                self.aug(images, poses, disps, intrinsics)

        # scale the scene so the mean valid disparity is 1, rescaling the
        # pose translations to keep the geometry consistent
        if len(disps[disps > 0.01]) > 0:
            s = disps[disps > 0.01].mean()
            disps = disps / s
            poses[..., :3] *= s

        return images, poses, disps, intrinsics

    def __len__(self):
        return len(self.dataset_index)

    def __imul__(self, x):
        # repeat the index list x times, so `db *= n` lengthens an epoch
        self.dataset_index *= x
        return self
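

# A minimal usage sketch (illustrative only, following the commented-example
# convention used above): RGBDDataset is abstract, so a dataset-specific
# subclass supplies _build_dataset and training code wraps the result in a
# standard PyTorch DataLoader. The class name DemoRGBD and the path
# 'datasets/demo' are hypothetical.
#
# class DemoRGBD(RGBDDataset):
#     def _build_dataset(self):
#         ...  # scan self.root and return
#              # {scene: {'images', 'depths', 'poses', 'intrinsics', 'graph'}}
#
# db = DemoRGBD('demo', 'datasets/demo', n_frames=7)
# loader = data.DataLoader(db, batch_size=4, shuffle=True, num_workers=2)
# for images, poses, disps, intrinsics in loader:
#     pass  # one training step per batch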